summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar comex2023-06-19 18:17:43 -0700
committerGravatar comex2023-06-25 12:53:31 -0700
commit8e703e08dfcf735a08df2ceff6a05221b7cc981f (patch)
tree771ebe71883ff9e179156f2b38b21b05070d7667 /src
parentMerge pull request #10825 from 8bitDream/vcpkg-zlib (diff)
downloadyuzu-8e703e08dfcf735a08df2ceff6a05221b7cc981f.tar.gz
yuzu-8e703e08dfcf735a08df2ceff6a05221b7cc981f.tar.xz
yuzu-8e703e08dfcf735a08df2ceff6a05221b7cc981f.zip
Implement SSL service
This implements some missing network APIs including a large chunk of the SSL service, enough for Mario Maker (with an appropriate mod applied) to connect to the fan server [Open Course World](https://opencourse.world/). Connecting to first-party servers is out of scope of this PR and is a minefield I'd rather not step into. ## TLS TLS is implemented with multiple backends depending on the system's 'native' TLS library. Currently there are two backends: Schannel for Windows, and OpenSSL for Linux. (In reality Linux is a bit of a free-for-all where there's no one 'native' library, but OpenSSL is the closest it gets.) On macOS the 'native' library is SecureTransport but that isn't implemented in this PR. (Instead, all non-Windows OSes will use OpenSSL unless disabled with `-DENABLE_OPENSSL=OFF`.) Why have multiple backends instead of just using a single library, especially given that Yuzu already embeds mbedtls for cryptographic algorithms? Well, I tried implementing this on mbedtls first, but the problem is TLS policies - mainly trusted certificate policies, and to a lesser extent trusted algorithms, SSL versions, etc. ...In practice, the chance that someone is going to conduct a man-in-the-middle attack on a third-party game server is pretty low, but I'm a security nerd so I like to do the right security things. My base assumption is that we want to use the host system's TLS policies. An alternative would be to more closely emulate the Switch's TLS implementation (which is based on NSS). But for one thing, I don't feel like reverse engineering it. And I'd argue that for third-party servers such as Open Course World, it's theoretically preferable to use the system's policies rather than the Switch's, for two reasons 1. Someday the Switch will stop being updated, and the trusted cert list, algorithms, etc. will start to go stale, but users will still want to connect to third-party servers, and there's no reason they shouldn't have up-to-date security when doing so. At that point, homebrew users on actual hardware may patch the TLS implementation, but for emulators it's simpler to just use the host's stack. 2. Also, it's good to respect any custom certificate policies the user may have added systemwide. For example, they may have added custom trusted CAs in order to use TLS debugging tools or pass through corporate MitM middleboxes. Or they may have removed some CAs that are normally trusted out of paranoia. Note that this policy wouldn't work as-is for connecting to first-party servers, because some of them serve certificates based on Nintendo's own CA rather than a publicly trusted one. However, this could probably be solved easily by using appropriate APIs to adding Nintendo's CA as an alternate trusted cert for Yuzu's connections. That is not implemented in this PR because, again, first-party servers are out of scope. (If anything I'd rather have an option to _block_ connections to Nintendo servers, but that's not implemented here.) To use the host's TLS policies, there are three theoretical options: a) Import the host's trusted certificate list into a cross-platform TLS library (presumably mbedtls). b) Use the native TLS library to verify certificates but use a cross-platform TLS library for everything else. c) Use the native TLS library for everything. Two problems with option a). First, importing the trusted certificate list at minimum requires a bunch of platform-specific code, which mbedtls does not have built in. Interestingly, OpenSSL recently gained the ability to import the Windows certificate trust store... but that leads to the second problem, which is that a list of trusted certificates is [not expressive enough](https://bugs.archlinux.org/task/41909) to express a modern certificate trust policy. For example, Windows has the concept of [explicitly distrusted certificates](https://learn.microsoft.com/en-us/previous-versions/windows/it-pro/windows-server-2012-r2-and-2012/dn265983(v=ws.11)), and macOS requires Certificate Transparency validation for some certificates with complex rules for when it's required. Option b) (using native library just to verify certs) is probably feasible, but it would miss aspects of TLS policy other than trusted certs (like allowed algorithms), and in any case it might well require writing more code, not less, compared to using the native library for everything. So I ended up at option c), using the native library for everything. What I'd *really* prefer would be to use a third-party library that does option c) for me. Rust has a good library for this, [native-tls](https://docs.rs/native-tls/latest/native_tls/). I did search, but I couldn't find a good option in the C or C++ ecosystem, at least not any that wasn't part of some much larger framework. I was surprised - isn't this a pretty common use case? Well, many applications only need TLS for HTTPS, and they can use libcurl, which has a TLS abstraction layer internally but doesn't expose it. Other applications only support a single TLS library, or use one of the aforementioned larger frameworks, or are platform-specific to begin with, or of course are written in a non-C/C++ language, most of which have some canonical choice for TLS. But there are also many applications that have a set of TLS backends just like this; it's just that nobody has gone ahead and abstracted the pattern into a library, at least not a widespread one. Amusingly, there is one TLS abstraction layer that Yuzu already bundles: the one in ffmpeg. But it is missing some features that would be needed to use it here (like reusing an existing socket rather than managing the socket itself). Though, that does mean that the wiki's build instructions for Linux (and macOS for some reason?) already recommend installing OpenSSL, so no need to update those. ## Other APIs implemented - Sockets: - GetSockOpt(`SO_ERROR`) - SetSockOpt(`SO_NOSIGPIPE`) (stub, I have no idea what this does on Switch) - `DuplicateSocket` (because the SSL sysmodule calls it internally) - More `PollEvents` values - NSD: - `Resolve` and `ResolveEx` (stub, good enough for Open Course World and probably most third-party servers, but not first-party) - SFDNSRES: - `GetHostByNameRequest` and `GetHostByNameRequestWithOptions` - `ResolverSetOptionRequest` (stub) ## Fixes - Parts of the socket code were previously allocating a `sockaddr` object on the stack when calling functions that take a `sockaddr*` (e.g. `accept`). This might seem like the right thing to do to avoid illegal aliasing, but in fact `sockaddr` is not guaranteed to be large enough to hold any particular type of address, only the header. This worked in practice because in practice `sockaddr` is the same size as `sockaddr_in`, but it's not how the API is meant to be used. I changed this to allocate an `sockaddr_in` on the stack and `reinterpret_cast` it. I could try to do something cleverer with `aligned_storage`, but casting is the idiomatic way to use these particular APIs, so it's really the system's responsibility to avoid any aliasing issues. - I rewrote most of the `GetAddrInfoRequest[WithOptions]` implementation. The old implementation invoked the host's getaddrinfo directly from sfdnsres.cpp, and directly passed through the host's socket type, protocol, etc. values rather than looking up the corresponding constants on the Switch. To be fair, these constants don't tend to actually vary across systems, but still... I added a wrapper for `getaddrinfo` in `internal_network/network.cpp` similar to the ones for other socket APIs, and changed the `GetAddrInfoRequest` implementation to use it. While I was at it, I rewrote the serialization to use the same approach I used to implement `GetHostByNameRequest`, because it reduces the number of size calculations. While doing so I removed `AF_INET6` support because the Switch doesn't support IPv6; it might be nice to support IPv6 anyway, but that would have to apply to all of the socket APIs. I also corrected the IPC wrappers for `GetAddrInfoRequest` and `GetAddrInfoRequestWithOptions` based on reverse engineering and hardware testing. Every call to `GetAddrInfoRequestWithOptions` returns *four* different error codes (IPC status, getaddrinfo error code, netdb error code, and errno), and `GetAddrInfoRequest` returns three of those but in a different order, and it doesn't really matter but the existing implementation was a bit off, as I discovered while testing `GetHostByNameRequest`. - The new serialization code is based on two simple helper functions: ```cpp template <typename T> static void Append(std::vector<u8>& vec, T t); void AppendNulTerminated(std::vector<u8>& vec, std::string_view str); ``` I was thinking there must be existing functions somewhere that assist with serialization/deserialization of binary data, but all I could find was the helper methods in `IOFile` and `HLERequestContext`, not anything that could be used with a generic byte buffer. If I'm not missing something, then maybe I should move the above functions to a new header in `common`... right now they're just sitting in `sfdnsres.cpp` where they're used. - Not a fix, but `SocketBase::Recv`/`Send` is changed to use `std::span<u8>` rather than `std::vector<u8>&` to avoid needing to copy the data to/from a vector when those methods are called from the TLS implementation.
Diffstat (limited to 'src')
-rw-r--r--src/common/socket_types.h16
-rw-r--r--src/core/CMakeLists.txt14
-rw-r--r--src/core/hle/service/sockets/bsd.cpp107
-rw-r--r--src/core/hle/service/sockets/bsd.h13
-rw-r--r--src/core/hle/service/sockets/nsd.cpp58
-rw-r--r--src/core/hle/service/sockets/nsd.h4
-rw-r--r--src/core/hle/service/sockets/sfdnsres.cpp345
-rw-r--r--src/core/hle/service/sockets/sfdnsres.h3
-rw-r--r--src/core/hle/service/sockets/sockets.h33
-rw-r--r--src/core/hle/service/sockets/sockets_translate.cpp114
-rw-r--r--src/core/hle/service/sockets/sockets_translate.h17
-rw-r--r--src/core/hle/service/ssl/ssl.cpp349
-rw-r--r--src/core/hle/service/ssl/ssl_backend.h44
-rw-r--r--src/core/hle/service/ssl/ssl_backend_none.cpp15
-rw-r--r--src/core/hle/service/ssl/ssl_backend_openssl.cpp342
-rw-r--r--src/core/hle/service/ssl/ssl_backend_schannel.cpp529
-rw-r--r--src/core/internal_network/network.cpp274
-rw-r--r--src/core/internal_network/network.h34
-rw-r--r--src/core/internal_network/socket_proxy.cpp22
-rw-r--r--src/core/internal_network/socket_proxy.h8
-rw-r--r--src/core/internal_network/sockets.h16
21 files changed, 2080 insertions, 277 deletions
diff --git a/src/common/socket_types.h b/src/common/socket_types.h
index 0a801a443..18ad6ac95 100644
--- a/src/common/socket_types.h
+++ b/src/common/socket_types.h
@@ -5,15 +5,19 @@
5 5
6#include "common/common_types.h" 6#include "common/common_types.h"
7 7
8#include <optional>
9
8namespace Network { 10namespace Network {
9 11
10/// Address families 12/// Address families
11enum class Domain : u8 { 13enum class Domain : u8 {
12 INET, ///< Address family for IPv4 14 Unspecified, ///< Can be 0 in getaddrinfo hints
15 INET, ///< Address family for IPv4
13}; 16};
14 17
15/// Socket types 18/// Socket types
16enum class Type { 19enum class Type {
20 Unspecified, ///< Can be 0 in getaddrinfo hints
17 STREAM, 21 STREAM,
18 DGRAM, 22 DGRAM,
19 RAW, 23 RAW,
@@ -22,6 +26,7 @@ enum class Type {
22 26
23/// Protocol values for sockets 27/// Protocol values for sockets
24enum class Protocol : u8 { 28enum class Protocol : u8 {
29 Unspecified,
25 ICMP, 30 ICMP,
26 TCP, 31 TCP,
27 UDP, 32 UDP,
@@ -48,4 +53,13 @@ constexpr u32 FLAG_MSG_PEEK = 0x2;
48constexpr u32 FLAG_MSG_DONTWAIT = 0x80; 53constexpr u32 FLAG_MSG_DONTWAIT = 0x80;
49constexpr u32 FLAG_O_NONBLOCK = 0x800; 54constexpr u32 FLAG_O_NONBLOCK = 0x800;
50 55
56/// Cross-platform addrinfo structure
57struct AddrInfo {
58 Domain family;
59 Type socket_type;
60 Protocol protocol;
61 SockAddrIn addr;
62 std::optional<std::string> canon_name;
63};
64
51} // namespace Network 65} // namespace Network
diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt
index 227c431bc..d95d2fe01 100644
--- a/src/core/CMakeLists.txt
+++ b/src/core/CMakeLists.txt
@@ -723,6 +723,7 @@ add_library(core STATIC
723 hle/service/spl/spl_types.h 723 hle/service/spl/spl_types.h
724 hle/service/ssl/ssl.cpp 724 hle/service/ssl/ssl.cpp
725 hle/service/ssl/ssl.h 725 hle/service/ssl/ssl.h
726 hle/service/ssl/ssl_backend.h
726 hle/service/time/clock_types.h 727 hle/service/time/clock_types.h
727 hle/service/time/ephemeral_network_system_clock_context_writer.h 728 hle/service/time/ephemeral_network_system_clock_context_writer.h
728 hle/service/time/ephemeral_network_system_clock_core.h 729 hle/service/time/ephemeral_network_system_clock_core.h
@@ -864,6 +865,19 @@ if (ARCHITECTURE_x86_64 OR ARCHITECTURE_arm64)
864 target_link_libraries(core PRIVATE dynarmic::dynarmic) 865 target_link_libraries(core PRIVATE dynarmic::dynarmic)
865endif() 866endif()
866 867
868if(ENABLE_OPENSSL)
869 target_sources(core PRIVATE
870 hle/service/ssl/ssl_backend_openssl.cpp)
871 target_link_libraries(core PRIVATE OpenSSL::SSL)
872elseif (WIN32)
873 target_sources(core PRIVATE
874 hle/service/ssl/ssl_backend_schannel.cpp)
875 target_link_libraries(core PRIVATE Secur32)
876else()
877 target_sources(core PRIVATE
878 hle/service/ssl/ssl_backend_none.cpp)
879endif()
880
867if (YUZU_USE_PRECOMPILED_HEADERS) 881if (YUZU_USE_PRECOMPILED_HEADERS)
868 target_precompile_headers(core PRIVATE precompiled_headers.h) 882 target_precompile_headers(core PRIVATE precompiled_headers.h)
869endif() 883endif()
diff --git a/src/core/hle/service/sockets/bsd.cpp b/src/core/hle/service/sockets/bsd.cpp
index bce45d321..6677689dc 100644
--- a/src/core/hle/service/sockets/bsd.cpp
+++ b/src/core/hle/service/sockets/bsd.cpp
@@ -20,6 +20,9 @@
20#include "core/internal_network/sockets.h" 20#include "core/internal_network/sockets.h"
21#include "network/network.h" 21#include "network/network.h"
22 22
23using Common::Expected;
24using Common::Unexpected;
25
23namespace Service::Sockets { 26namespace Service::Sockets {
24 27
25namespace { 28namespace {
@@ -265,16 +268,19 @@ void BSD::GetSockOpt(HLERequestContext& ctx) {
265 const u32 level = rp.Pop<u32>(); 268 const u32 level = rp.Pop<u32>();
266 const auto optname = static_cast<OptName>(rp.Pop<u32>()); 269 const auto optname = static_cast<OptName>(rp.Pop<u32>());
267 270
268 LOG_WARNING(Service, "(STUBBED) called. fd={} level={} optname=0x{:x}", fd, level, optname);
269
270 std::vector<u8> optval(ctx.GetWriteBufferSize()); 271 std::vector<u8> optval(ctx.GetWriteBufferSize());
271 272
273 LOG_WARNING(Service, "called. fd={} level={} optname=0x{:x} len=0x{:x}", fd, level, optname,
274 optval.size());
275
276 Errno err = GetSockOptImpl(fd, level, optname, optval);
277
272 ctx.WriteBuffer(optval); 278 ctx.WriteBuffer(optval);
273 279
274 IPC::ResponseBuilder rb{ctx, 5}; 280 IPC::ResponseBuilder rb{ctx, 5};
275 rb.Push(ResultSuccess); 281 rb.Push(ResultSuccess);
276 rb.Push<s32>(-1); 282 rb.Push<s32>(err == Errno::SUCCESS ? 0 : -1);
277 rb.PushEnum(Errno::NOTCONN); 283 rb.PushEnum(err);
278 rb.Push<u32>(static_cast<u32>(optval.size())); 284 rb.Push<u32>(static_cast<u32>(optval.size()));
279} 285}
280 286
@@ -436,6 +442,18 @@ void BSD::Close(HLERequestContext& ctx) {
436 BuildErrnoResponse(ctx, CloseImpl(fd)); 442 BuildErrnoResponse(ctx, CloseImpl(fd));
437} 443}
438 444
445void BSD::DuplicateSocket(HLERequestContext& ctx) {
446 IPC::RequestParser rp{ctx};
447 const s32 fd = rp.Pop<s32>();
448 [[maybe_unused]] const u64 unused = rp.Pop<u64>();
449
450 Common::Expected<s32, Errno> res = DuplicateSocketImpl(fd);
451 IPC::ResponseBuilder rb{ctx, 4};
452 rb.Push(ResultSuccess);
453 rb.Push(res.value_or(0)); // ret
454 rb.Push(res ? 0 : static_cast<s32>(res.error())); // bsd errno
455}
456
439void BSD::EventFd(HLERequestContext& ctx) { 457void BSD::EventFd(HLERequestContext& ctx) {
440 IPC::RequestParser rp{ctx}; 458 IPC::RequestParser rp{ctx};
441 const u64 initval = rp.Pop<u64>(); 459 const u64 initval = rp.Pop<u64>();
@@ -477,12 +495,12 @@ std::pair<s32, Errno> BSD::SocketImpl(Domain domain, Type type, Protocol protoco
477 495
478 auto room_member = room_network.GetRoomMember().lock(); 496 auto room_member = room_network.GetRoomMember().lock();
479 if (room_member && room_member->IsConnected()) { 497 if (room_member && room_member->IsConnected()) {
480 descriptor.socket = std::make_unique<Network::ProxySocket>(room_network); 498 descriptor.socket = std::make_shared<Network::ProxySocket>(room_network);
481 } else { 499 } else {
482 descriptor.socket = std::make_unique<Network::Socket>(); 500 descriptor.socket = std::make_shared<Network::Socket>();
483 } 501 }
484 502
485 descriptor.socket->Initialize(Translate(domain), Translate(type), Translate(type, protocol)); 503 descriptor.socket->Initialize(Translate(domain), Translate(type), Translate(protocol));
486 descriptor.is_connection_based = IsConnectionBased(type); 504 descriptor.is_connection_based = IsConnectionBased(type);
487 505
488 return {fd, Errno::SUCCESS}; 506 return {fd, Errno::SUCCESS};
@@ -538,7 +556,7 @@ std::pair<s32, Errno> BSD::PollImpl(std::vector<u8>& write_buffer, std::span<con
538 std::transform(fds.begin(), fds.end(), host_pollfds.begin(), [this](PollFD pollfd) { 556 std::transform(fds.begin(), fds.end(), host_pollfds.begin(), [this](PollFD pollfd) {
539 Network::PollFD result; 557 Network::PollFD result;
540 result.socket = file_descriptors[pollfd.fd]->socket.get(); 558 result.socket = file_descriptors[pollfd.fd]->socket.get();
541 result.events = TranslatePollEventsToHost(pollfd.events); 559 result.events = Translate(pollfd.events);
542 result.revents = Network::PollEvents{}; 560 result.revents = Network::PollEvents{};
543 return result; 561 return result;
544 }); 562 });
@@ -547,7 +565,7 @@ std::pair<s32, Errno> BSD::PollImpl(std::vector<u8>& write_buffer, std::span<con
547 565
548 const size_t num = host_pollfds.size(); 566 const size_t num = host_pollfds.size();
549 for (size_t i = 0; i < num; ++i) { 567 for (size_t i = 0; i < num; ++i) {
550 fds[i].revents = TranslatePollEventsToGuest(host_pollfds[i].revents); 568 fds[i].revents = Translate(host_pollfds[i].revents);
551 } 569 }
552 std::memcpy(write_buffer.data(), fds.data(), length); 570 std::memcpy(write_buffer.data(), fds.data(), length);
553 571
@@ -617,7 +635,8 @@ Errno BSD::GetPeerNameImpl(s32 fd, std::vector<u8>& write_buffer) {
617 } 635 }
618 const SockAddrIn guest_addrin = Translate(addr_in); 636 const SockAddrIn guest_addrin = Translate(addr_in);
619 637
620 ASSERT(write_buffer.size() == sizeof(guest_addrin)); 638 ASSERT(write_buffer.size() >= sizeof(guest_addrin));
639 write_buffer.resize(sizeof(guest_addrin));
621 std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin)); 640 std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin));
622 return Translate(bsd_errno); 641 return Translate(bsd_errno);
623} 642}
@@ -633,7 +652,8 @@ Errno BSD::GetSockNameImpl(s32 fd, std::vector<u8>& write_buffer) {
633 } 652 }
634 const SockAddrIn guest_addrin = Translate(addr_in); 653 const SockAddrIn guest_addrin = Translate(addr_in);
635 654
636 ASSERT(write_buffer.size() == sizeof(guest_addrin)); 655 ASSERT(write_buffer.size() >= sizeof(guest_addrin));
656 write_buffer.resize(sizeof(guest_addrin));
637 std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin)); 657 std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin));
638 return Translate(bsd_errno); 658 return Translate(bsd_errno);
639} 659}
@@ -671,13 +691,47 @@ std::pair<s32, Errno> BSD::FcntlImpl(s32 fd, FcntlCmd cmd, s32 arg) {
671 } 691 }
672} 692}
673 693
674Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval) { 694Errno BSD::GetSockOptImpl(s32 fd, u32 level, OptName optname, std::vector<u8>& optval) {
675 UNIMPLEMENTED_IF(level != 0xffff); // SOL_SOCKET 695 if (!IsFileDescriptorValid(fd)) {
696 return Errno::BADF;
697 }
698
699 if (level != static_cast<u32>(SocketLevel::SOCKET)) {
700 UNIMPLEMENTED_MSG("Unknown getsockopt level");
701 return Errno::SUCCESS;
702 }
703
704 Network::SocketBase* const socket = file_descriptors[fd]->socket.get();
705
706 switch (optname) {
707 case OptName::ERROR_: {
708 auto [pending_err, getsockopt_err] = socket->GetPendingError();
709 if (getsockopt_err == Network::Errno::SUCCESS) {
710 Errno translated_pending_err = Translate(pending_err);
711 ASSERT_OR_EXECUTE_MSG(
712 optval.size() == sizeof(Errno), { return Errno::INVAL; },
713 "Incorrect getsockopt option size");
714 optval.resize(sizeof(Errno));
715 memcpy(optval.data(), &translated_pending_err, sizeof(Errno));
716 }
717 return Translate(getsockopt_err);
718 }
719 default:
720 UNIMPLEMENTED_MSG("Unimplemented optname={}", optname);
721 return Errno::SUCCESS;
722 }
723}
676 724
725Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval) {
677 if (!IsFileDescriptorValid(fd)) { 726 if (!IsFileDescriptorValid(fd)) {
678 return Errno::BADF; 727 return Errno::BADF;
679 } 728 }
680 729
730 if (level != static_cast<u32>(SocketLevel::SOCKET)) {
731 UNIMPLEMENTED_MSG("Unknown setsockopt level");
732 return Errno::SUCCESS;
733 }
734
681 Network::SocketBase* const socket = file_descriptors[fd]->socket.get(); 735 Network::SocketBase* const socket = file_descriptors[fd]->socket.get();
682 736
683 if (optname == OptName::LINGER) { 737 if (optname == OptName::LINGER) {
@@ -711,6 +765,9 @@ Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, con
711 return Translate(socket->SetSndTimeo(value)); 765 return Translate(socket->SetSndTimeo(value));
712 case OptName::RCVTIMEO: 766 case OptName::RCVTIMEO:
713 return Translate(socket->SetRcvTimeo(value)); 767 return Translate(socket->SetRcvTimeo(value));
768 case OptName::NOSIGPIPE:
769 LOG_WARNING(Service, "(STUBBED) setting NOSIGPIPE to {}", value);
770 return Errno::SUCCESS;
714 default: 771 default:
715 UNIMPLEMENTED_MSG("Unimplemented optname={}", optname); 772 UNIMPLEMENTED_MSG("Unimplemented optname={}", optname);
716 return Errno::SUCCESS; 773 return Errno::SUCCESS;
@@ -841,6 +898,28 @@ Errno BSD::CloseImpl(s32 fd) {
841 return bsd_errno; 898 return bsd_errno;
842} 899}
843 900
901Expected<s32, Errno> BSD::DuplicateSocketImpl(s32 fd) {
902 if (!IsFileDescriptorValid(fd)) {
903 return Unexpected(Errno::BADF);
904 }
905
906 const s32 new_fd = FindFreeFileDescriptorHandle();
907 if (new_fd < 0) {
908 LOG_ERROR(Service, "No more file descriptors available");
909 return Unexpected(Errno::MFILE);
910 }
911
912 file_descriptors[new_fd] = file_descriptors[fd];
913 return new_fd;
914}
915
916std::optional<std::shared_ptr<Network::SocketBase>> BSD::GetSocket(s32 fd) {
917 if (!IsFileDescriptorValid(fd)) {
918 return std::nullopt;
919 }
920 return file_descriptors[fd]->socket;
921}
922
844s32 BSD::FindFreeFileDescriptorHandle() noexcept { 923s32 BSD::FindFreeFileDescriptorHandle() noexcept {
845 for (s32 fd = 0; fd < static_cast<s32>(file_descriptors.size()); ++fd) { 924 for (s32 fd = 0; fd < static_cast<s32>(file_descriptors.size()); ++fd) {
846 if (!file_descriptors[fd]) { 925 if (!file_descriptors[fd]) {
@@ -911,7 +990,7 @@ BSD::BSD(Core::System& system_, const char* name)
911 {24, &BSD::Write, "Write"}, 990 {24, &BSD::Write, "Write"},
912 {25, &BSD::Read, "Read"}, 991 {25, &BSD::Read, "Read"},
913 {26, &BSD::Close, "Close"}, 992 {26, &BSD::Close, "Close"},
914 {27, nullptr, "DuplicateSocket"}, 993 {27, &BSD::DuplicateSocket, "DuplicateSocket"},
915 {28, nullptr, "GetResourceStatistics"}, 994 {28, nullptr, "GetResourceStatistics"},
916 {29, nullptr, "RecvMMsg"}, 995 {29, nullptr, "RecvMMsg"},
917 {30, nullptr, "SendMMsg"}, 996 {30, nullptr, "SendMMsg"},
diff --git a/src/core/hle/service/sockets/bsd.h b/src/core/hle/service/sockets/bsd.h
index 30ae9c140..430edb97c 100644
--- a/src/core/hle/service/sockets/bsd.h
+++ b/src/core/hle/service/sockets/bsd.h
@@ -8,6 +8,7 @@
8#include <vector> 8#include <vector>
9 9
10#include "common/common_types.h" 10#include "common/common_types.h"
11#include "common/expected.h"
11#include "common/socket_types.h" 12#include "common/socket_types.h"
12#include "core/hle/service/service.h" 13#include "core/hle/service/service.h"
13#include "core/hle/service/sockets/sockets.h" 14#include "core/hle/service/sockets/sockets.h"
@@ -29,12 +30,19 @@ public:
29 explicit BSD(Core::System& system_, const char* name); 30 explicit BSD(Core::System& system_, const char* name);
30 ~BSD() override; 31 ~BSD() override;
31 32
33 // These methods are called from SSL; the first two are also called from
34 // this class for the corresponding IPC methods.
35 // On the real device, the SSL service makes IPC calls to this service.
36 Common::Expected<s32, Errno> DuplicateSocketImpl(s32 fd);
37 Errno CloseImpl(s32 fd);
38 std::optional<std::shared_ptr<Network::SocketBase>> GetSocket(s32 fd);
39
32private: 40private:
33 /// Maximum number of file descriptors 41 /// Maximum number of file descriptors
34 static constexpr size_t MAX_FD = 128; 42 static constexpr size_t MAX_FD = 128;
35 43
36 struct FileDescriptor { 44 struct FileDescriptor {
37 std::unique_ptr<Network::SocketBase> socket; 45 std::shared_ptr<Network::SocketBase> socket;
38 s32 flags = 0; 46 s32 flags = 0;
39 bool is_connection_based = false; 47 bool is_connection_based = false;
40 }; 48 };
@@ -138,6 +146,7 @@ private:
138 void Write(HLERequestContext& ctx); 146 void Write(HLERequestContext& ctx);
139 void Read(HLERequestContext& ctx); 147 void Read(HLERequestContext& ctx);
140 void Close(HLERequestContext& ctx); 148 void Close(HLERequestContext& ctx);
149 void DuplicateSocket(HLERequestContext& ctx);
141 void EventFd(HLERequestContext& ctx); 150 void EventFd(HLERequestContext& ctx);
142 151
143 template <typename Work> 152 template <typename Work>
@@ -153,6 +162,7 @@ private:
153 Errno GetSockNameImpl(s32 fd, std::vector<u8>& write_buffer); 162 Errno GetSockNameImpl(s32 fd, std::vector<u8>& write_buffer);
154 Errno ListenImpl(s32 fd, s32 backlog); 163 Errno ListenImpl(s32 fd, s32 backlog);
155 std::pair<s32, Errno> FcntlImpl(s32 fd, FcntlCmd cmd, s32 arg); 164 std::pair<s32, Errno> FcntlImpl(s32 fd, FcntlCmd cmd, s32 arg);
165 Errno GetSockOptImpl(s32 fd, u32 level, OptName optname, std::vector<u8>& optval);
156 Errno SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval); 166 Errno SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval);
157 Errno ShutdownImpl(s32 fd, s32 how); 167 Errno ShutdownImpl(s32 fd, s32 how);
158 std::pair<s32, Errno> RecvImpl(s32 fd, u32 flags, std::vector<u8>& message); 168 std::pair<s32, Errno> RecvImpl(s32 fd, u32 flags, std::vector<u8>& message);
@@ -161,7 +171,6 @@ private:
161 std::pair<s32, Errno> SendImpl(s32 fd, u32 flags, std::span<const u8> message); 171 std::pair<s32, Errno> SendImpl(s32 fd, u32 flags, std::span<const u8> message);
162 std::pair<s32, Errno> SendToImpl(s32 fd, u32 flags, std::span<const u8> message, 172 std::pair<s32, Errno> SendToImpl(s32 fd, u32 flags, std::span<const u8> message,
163 std::span<const u8> addr); 173 std::span<const u8> addr);
164 Errno CloseImpl(s32 fd);
165 174
166 s32 FindFreeFileDescriptorHandle() noexcept; 175 s32 FindFreeFileDescriptorHandle() noexcept;
167 bool IsFileDescriptorValid(s32 fd) const noexcept; 176 bool IsFileDescriptorValid(s32 fd) const noexcept;
diff --git a/src/core/hle/service/sockets/nsd.cpp b/src/core/hle/service/sockets/nsd.cpp
index 6491a73be..22c3a31a0 100644
--- a/src/core/hle/service/sockets/nsd.cpp
+++ b/src/core/hle/service/sockets/nsd.cpp
@@ -1,10 +1,15 @@
1// SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project 1// SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later 2// SPDX-License-Identifier: GPL-2.0-or-later
3 3
4#include "core/hle/service/ipc_helpers.h"
4#include "core/hle/service/sockets/nsd.h" 5#include "core/hle/service/sockets/nsd.h"
5 6
7#include "common/string_util.h"
8
6namespace Service::Sockets { 9namespace Service::Sockets {
7 10
11constexpr Result ResultOverflow{ErrorModule::NSD, 6};
12
8NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, name} { 13NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, name} {
9 // clang-format off 14 // clang-format off
10 static const FunctionInfo functions[] = { 15 static const FunctionInfo functions[] = {
@@ -15,8 +20,8 @@ NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, na
15 {13, nullptr, "DeleteSettings"}, 20 {13, nullptr, "DeleteSettings"},
16 {14, nullptr, "ImportSettings"}, 21 {14, nullptr, "ImportSettings"},
17 {15, nullptr, "SetChangeEnvironmentIdentifierDisabled"}, 22 {15, nullptr, "SetChangeEnvironmentIdentifierDisabled"},
18 {20, nullptr, "Resolve"}, 23 {20, &NSD::Resolve, "Resolve"},
19 {21, nullptr, "ResolveEx"}, 24 {21, &NSD::ResolveEx, "ResolveEx"},
20 {30, nullptr, "GetNasServiceSetting"}, 25 {30, nullptr, "GetNasServiceSetting"},
21 {31, nullptr, "GetNasServiceSettingEx"}, 26 {31, nullptr, "GetNasServiceSettingEx"},
22 {40, nullptr, "GetNasRequestFqdn"}, 27 {40, nullptr, "GetNasRequestFqdn"},
@@ -40,6 +45,55 @@ NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, na
40 RegisterHandlers(functions); 45 RegisterHandlers(functions);
41} 46}
42 47
48static ResultVal<std::string> ResolveImpl(const std::string& fqdn_in) {
49 // The real implementation makes various substitutions.
50 // For now we just return the string as-is, which is good enough when not
51 // connecting to real Nintendo servers.
52 LOG_WARNING(Service, "(STUBBED) called({})", fqdn_in);
53 return fqdn_in;
54}
55
56static Result ResolveCommon(const std::string& fqdn_in, std::array<char, 0x100>& fqdn_out) {
57 const auto res = ResolveImpl(fqdn_in);
58 if (res.Failed()) {
59 return res.Code();
60 }
61 if (res->size() >= fqdn_out.size()) {
62 return ResultOverflow;
63 }
64 std::memcpy(fqdn_out.data(), res->c_str(), res->size() + 1);
65 return ResultSuccess;
66}
67
68void NSD::Resolve(HLERequestContext& ctx) {
69 const std::string fqdn_in = Common::StringFromBuffer(ctx.ReadBuffer(0));
70
71 std::array<char, 0x100> fqdn_out{};
72 Result res = ResolveCommon(fqdn_in, fqdn_out);
73
74 ctx.WriteBuffer(fqdn_out);
75 IPC::ResponseBuilder rb{ctx, 2};
76 rb.Push(res);
77}
78
79void NSD::ResolveEx(HLERequestContext& ctx) {
80 const std::string fqdn_in = Common::StringFromBuffer(ctx.ReadBuffer(0));
81
82 std::array<char, 0x100> fqdn_out;
83 Result res = ResolveCommon(fqdn_in, fqdn_out);
84
85 if (res.IsError()) {
86 IPC::ResponseBuilder rb{ctx, 2};
87 rb.Push(res);
88 return;
89 }
90
91 ctx.WriteBuffer(fqdn_out);
92 IPC::ResponseBuilder rb{ctx, 4};
93 rb.Push(ResultSuccess);
94 rb.Push(ResultSuccess);
95}
96
43NSD::~NSD() = default; 97NSD::~NSD() = default;
44 98
45} // namespace Service::Sockets 99} // namespace Service::Sockets
diff --git a/src/core/hle/service/sockets/nsd.h b/src/core/hle/service/sockets/nsd.h
index 5cc12b855..a7379a8a9 100644
--- a/src/core/hle/service/sockets/nsd.h
+++ b/src/core/hle/service/sockets/nsd.h
@@ -15,6 +15,10 @@ class NSD final : public ServiceFramework<NSD> {
15public: 15public:
16 explicit NSD(Core::System& system_, const char* name); 16 explicit NSD(Core::System& system_, const char* name);
17 ~NSD() override; 17 ~NSD() override;
18
19private:
20 void Resolve(HLERequestContext& ctx);
21 void ResolveEx(HLERequestContext& ctx);
18}; 22};
19 23
20} // namespace Service::Sockets 24} // namespace Service::Sockets
diff --git a/src/core/hle/service/sockets/sfdnsres.cpp b/src/core/hle/service/sockets/sfdnsres.cpp
index 132dd5797..1196fb86c 100644
--- a/src/core/hle/service/sockets/sfdnsres.cpp
+++ b/src/core/hle/service/sockets/sfdnsres.cpp
@@ -10,27 +10,18 @@
10#include "core/core.h" 10#include "core/core.h"
11#include "core/hle/service/ipc_helpers.h" 11#include "core/hle/service/ipc_helpers.h"
12#include "core/hle/service/sockets/sfdnsres.h" 12#include "core/hle/service/sockets/sfdnsres.h"
13#include "core/hle/service/sockets/sockets.h"
14#include "core/hle/service/sockets/sockets_translate.h"
15#include "core/internal_network/network.h"
13#include "core/memory.h" 16#include "core/memory.h"
14 17
15#ifdef _WIN32
16#include <ws2tcpip.h>
17#elif YUZU_UNIX
18#include <arpa/inet.h>
19#include <netdb.h>
20#include <netinet/in.h>
21#include <sys/socket.h>
22#ifndef EAI_NODATA
23#define EAI_NODATA EAI_NONAME
24#endif
25#endif
26
27namespace Service::Sockets { 18namespace Service::Sockets {
28 19
29SFDNSRES::SFDNSRES(Core::System& system_) : ServiceFramework{system_, "sfdnsres"} { 20SFDNSRES::SFDNSRES(Core::System& system_) : ServiceFramework{system_, "sfdnsres"} {
30 static const FunctionInfo functions[] = { 21 static const FunctionInfo functions[] = {
31 {0, nullptr, "SetDnsAddressesPrivateRequest"}, 22 {0, nullptr, "SetDnsAddressesPrivateRequest"},
32 {1, nullptr, "GetDnsAddressPrivateRequest"}, 23 {1, nullptr, "GetDnsAddressPrivateRequest"},
33 {2, nullptr, "GetHostByNameRequest"}, 24 {2, &SFDNSRES::GetHostByNameRequest, "GetHostByNameRequest"},
34 {3, nullptr, "GetHostByAddrRequest"}, 25 {3, nullptr, "GetHostByAddrRequest"},
35 {4, nullptr, "GetHostStringErrorRequest"}, 26 {4, nullptr, "GetHostStringErrorRequest"},
36 {5, nullptr, "GetGaiStringErrorRequest"}, 27 {5, nullptr, "GetGaiStringErrorRequest"},
@@ -38,11 +29,11 @@ SFDNSRES::SFDNSRES(Core::System& system_) : ServiceFramework{system_, "sfdnsres"
38 {7, nullptr, "GetNameInfoRequest"}, 29 {7, nullptr, "GetNameInfoRequest"},
39 {8, nullptr, "RequestCancelHandleRequest"}, 30 {8, nullptr, "RequestCancelHandleRequest"},
40 {9, nullptr, "CancelRequest"}, 31 {9, nullptr, "CancelRequest"},
41 {10, nullptr, "GetHostByNameRequestWithOptions"}, 32 {10, &SFDNSRES::GetHostByNameRequestWithOptions, "GetHostByNameRequestWithOptions"},
42 {11, nullptr, "GetHostByAddrRequestWithOptions"}, 33 {11, nullptr, "GetHostByAddrRequestWithOptions"},
43 {12, &SFDNSRES::GetAddrInfoRequestWithOptions, "GetAddrInfoRequestWithOptions"}, 34 {12, &SFDNSRES::GetAddrInfoRequestWithOptions, "GetAddrInfoRequestWithOptions"},
44 {13, nullptr, "GetNameInfoRequestWithOptions"}, 35 {13, nullptr, "GetNameInfoRequestWithOptions"},
45 {14, nullptr, "ResolverSetOptionRequest"}, 36 {14, &SFDNSRES::ResolverSetOptionRequest, "ResolverSetOptionRequest"},
46 {15, nullptr, "ResolverGetOptionRequest"}, 37 {15, nullptr, "ResolverGetOptionRequest"},
47 }; 38 };
48 RegisterHandlers(functions); 39 RegisterHandlers(functions);
@@ -59,188 +50,246 @@ enum class NetDbError : s32 {
59 NoData = 4, 50 NoData = 4,
60}; 51};
61 52
62static NetDbError AddrInfoErrorToNetDbError(s32 result) { 53static NetDbError GetAddrInfoErrorToNetDbError(GetAddrInfoError result) {
63 // Best effort guess to map errors 54 // These combinations have been verified on console (but are not
55 // exhaustive).
64 switch (result) { 56 switch (result) {
65 case 0: 57 case GetAddrInfoError::SUCCESS:
66 return NetDbError::Success; 58 return NetDbError::Success;
67 case EAI_AGAIN: 59 case GetAddrInfoError::AGAIN:
68 return NetDbError::TryAgain; 60 return NetDbError::TryAgain;
69 case EAI_NODATA: 61 case GetAddrInfoError::NODATA:
70 return NetDbError::NoData; 62 return NetDbError::HostNotFound;
63 case GetAddrInfoError::SERVICE:
64 return NetDbError::Success;
71 default: 65 default:
72 return NetDbError::HostNotFound; 66 return NetDbError::HostNotFound;
73 } 67 }
74} 68}
75 69
76static std::vector<u8> SerializeAddrInfo(const addrinfo* addrinfo, s32 result_code, 70static Errno GetAddrInfoErrorToErrno(GetAddrInfoError result) {
71 // These combinations have been verified on console (but are not
72 // exhaustive).
73 switch (result) {
74 case GetAddrInfoError::SUCCESS:
75 // Note: Sometimes a successful lookup sets errno to EADDRNOTAVAIL for
76 // some reason, but that doesn't seem useful to implement.
77 return Errno::SUCCESS;
78 case GetAddrInfoError::AGAIN:
79 return Errno::SUCCESS;
80 case GetAddrInfoError::NODATA:
81 return Errno::SUCCESS;
82 case GetAddrInfoError::SERVICE:
83 return Errno::INVAL;
84 default:
85 return Errno::SUCCESS;
86 }
87}
88
89template <typename T>
90static void Append(std::vector<u8>& vec, T t) {
91 size_t off = vec.size();
92 vec.resize(off + sizeof(T));
93 std::memcpy(vec.data() + off, &t, sizeof(T));
94}
95
96static void AppendNulTerminated(std::vector<u8>& vec, std::string_view str) {
97 size_t off = vec.size();
98 vec.resize(off + str.size() + 1);
99 std::memcpy(vec.data() + off, str.data(), str.size());
100}
101
102// We implement gethostbyname using the host's getaddrinfo rather than the
103// host's gethostbyname, because it simplifies portability: e.g., getaddrinfo
104// behaves the same on Unix and Windows, unlike gethostbyname where Windows
105// doesn't implement h_errno.
106static std::vector<u8> SerializeAddrInfoAsHostEnt(const std::vector<Network::AddrInfo>& vec,
107 std::string_view host) {
108
109 std::vector<u8> data;
110 // h_name: use the input hostname (append nul-terminated)
111 AppendNulTerminated(data, host);
112 // h_aliases: leave empty
113
114 Append<u32_be>(data, 0); // count of h_aliases
115 // (If the count were nonzero, the aliases would be appended as nul-terminated here.)
116 Append<u16_be>(data, static_cast<u16>(Domain::INET)); // h_addrtype
117 Append<u16_be>(data, sizeof(Network::IPv4Address)); // h_length
118 // h_addr_list:
119 size_t count = vec.size();
120 ASSERT(count <= UINT32_MAX);
121 Append<u32_be>(data, static_cast<uint32_t>(count));
122 for (const Network::AddrInfo& addrinfo : vec) {
123 // On the Switch, this is passed through htonl despite already being
124 // big-endian, so it ends up as little-endian.
125 Append<u32_le>(data, Network::IPv4AddressToInteger(addrinfo.addr.ip));
126
127 LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host,
128 Network::IPv4AddressToString(addrinfo.addr.ip));
129 }
130 return data;
131}
132
133static std::pair<u32, GetAddrInfoError> GetHostByNameRequestImpl(HLERequestContext& ctx) {
134 struct Parameters {
135 u8 use_nsd_resolve;
136 u32 cancel_handle;
137 u64 process_id;
138 };
139
140 IPC::RequestParser rp{ctx};
141 const auto parameters = rp.PopRaw<Parameters>();
142
143 LOG_WARNING(
144 Service,
145 "called with ignored parameters: use_nsd_resolve={}, cancel_handle={}, process_id={}",
146 parameters.use_nsd_resolve, parameters.cancel_handle, parameters.process_id);
147
148 const auto host_buffer = ctx.ReadBuffer(0);
149 const std::string host = Common::StringFromBuffer(host_buffer);
150 // For now, ignore options, which are in input buffer 1 for GetHostByNameRequestWithOptions.
151
152 auto res = Network::GetAddrInfo(host, /*service*/ std::nullopt);
153 if (!res.has_value()) {
154 return {0, Translate(res.error())};
155 }
156
157 std::vector<u8> data = SerializeAddrInfoAsHostEnt(res.value(), host);
158 u32 data_size = static_cast<u32>(data.size());
159 ctx.WriteBuffer(data, 0);
160
161 return {data_size, GetAddrInfoError::SUCCESS};
162}
163
164void SFDNSRES::GetHostByNameRequest(HLERequestContext& ctx) {
165 auto [data_size, emu_gai_err] = GetHostByNameRequestImpl(ctx);
166
167 IPC::ResponseBuilder rb{ctx, 5};
168 rb.Push(ResultSuccess);
169 rb.Push(static_cast<s32>(GetAddrInfoErrorToNetDbError(emu_gai_err))); // netdb error code
170 rb.Push(static_cast<s32>(GetAddrInfoErrorToErrno(emu_gai_err))); // errno
171 rb.Push(data_size); // serialized size
172}
173
174void SFDNSRES::GetHostByNameRequestWithOptions(HLERequestContext& ctx) {
175 auto [data_size, emu_gai_err] = GetHostByNameRequestImpl(ctx);
176
177 IPC::ResponseBuilder rb{ctx, 5};
178 rb.Push(ResultSuccess);
179 rb.Push(data_size); // serialized size
180 rb.Push(static_cast<s32>(GetAddrInfoErrorToNetDbError(emu_gai_err))); // netdb error code
181 rb.Push(static_cast<s32>(GetAddrInfoErrorToErrno(emu_gai_err))); // errno
182}
183
184static std::vector<u8> SerializeAddrInfo(const std::vector<Network::AddrInfo>& vec,
77 std::string_view host) { 185 std::string_view host) {
78 // Adapted from 186 // Adapted from
79 // https://github.com/switchbrew/libnx/blob/c5a9a909a91657a9818a3b7e18c9b91ff0cbb6e3/nx/source/runtime/resolver.c#L190 187 // https://github.com/switchbrew/libnx/blob/c5a9a909a91657a9818a3b7e18c9b91ff0cbb6e3/nx/source/runtime/resolver.c#L190
80 std::vector<u8> data; 188 std::vector<u8> data;
81 189
82 auto* current = addrinfo; 190 for (const Network::AddrInfo& addrinfo : vec) {
83 while (current != nullptr) { 191 // serialized addrinfo:
84 struct SerializedResponseHeader { 192 Append<u32_be>(data, 0xBEEFCAFE); // magic
85 u32 magic; 193 Append<u32_be>(data, 0); // ai_flags
86 s32 flags; 194 Append<u32_be>(data, static_cast<u32>(Translate(addrinfo.family))); // ai_family
87 s32 family; 195 Append<u32_be>(data, static_cast<u32>(Translate(addrinfo.socket_type))); // ai_socktype
88 s32 socket_type; 196 Append<u32_be>(data, static_cast<u32>(Translate(addrinfo.protocol))); // ai_protocol
89 s32 protocol; 197 Append<u32_be>(data, sizeof(SockAddrIn)); // ai_addrlen
90 u32 address_length; 198 // ^ *not* sizeof(SerializedSockAddrIn), not that it matters since they're the same size
91 }; 199
92 static_assert(sizeof(SerializedResponseHeader) == 0x18, 200 // ai_addr:
93 "Response header size must be 0x18 bytes"); 201 Append<u16_be>(data, static_cast<u16>(Translate(addrinfo.addr.family))); // sin_family
94 202 // On the Switch, the following fields are passed through htonl despite
95 constexpr auto header_size = sizeof(SerializedResponseHeader); 203 // already being big-endian, so they end up as little-endian.
96 const auto addr_size = 204 Append<u16_le>(data, addrinfo.addr.portno); // sin_port
97 current->ai_addr && current->ai_addrlen > 0 ? current->ai_addrlen : 4; 205 Append<u32_le>(data, Network::IPv4AddressToInteger(addrinfo.addr.ip)); // sin_addr
98 const auto canonname_size = current->ai_canonname ? strlen(current->ai_canonname) + 1 : 1; 206 data.resize(data.size() + 8, 0); // sin_zero
99 207
100 const auto last_size = data.size(); 208 if (addrinfo.canon_name.has_value()) {
101 data.resize(last_size + header_size + addr_size + canonname_size); 209 AppendNulTerminated(data, *addrinfo.canon_name);
102
103 // Header in network byte order
104 SerializedResponseHeader header{};
105
106 constexpr auto HEADER_MAGIC = 0xBEEFCAFE;
107 header.magic = htonl(HEADER_MAGIC);
108 header.family = htonl(current->ai_family);
109 header.flags = htonl(current->ai_flags);
110 header.socket_type = htonl(current->ai_socktype);
111 header.protocol = htonl(current->ai_protocol);
112 header.address_length = current->ai_addr ? htonl((u32)current->ai_addrlen) : 0;
113
114 auto* header_ptr = data.data() + last_size;
115 std::memcpy(header_ptr, &header, header_size);
116
117 if (header.address_length == 0) {
118 std::memset(header_ptr + header_size, 0, 4);
119 } else {
120 switch (current->ai_family) {
121 case AF_INET: {
122 struct SockAddrIn {
123 s16 sin_family;
124 u16 sin_port;
125 u32 sin_addr;
126 u8 sin_zero[8];
127 };
128
129 SockAddrIn serialized_addr{};
130 const auto addr = *reinterpret_cast<sockaddr_in*>(current->ai_addr);
131 serialized_addr.sin_port = htons(addr.sin_port);
132 serialized_addr.sin_family = htons(addr.sin_family);
133 serialized_addr.sin_addr = htonl(addr.sin_addr.s_addr);
134 std::memcpy(header_ptr + header_size, &serialized_addr, sizeof(SockAddrIn));
135
136 char addr_string_buf[64]{};
137 inet_ntop(AF_INET, &addr.sin_addr, addr_string_buf, std::size(addr_string_buf));
138 LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host, addr_string_buf);
139 break;
140 }
141 case AF_INET6: {
142 struct SockAddrIn6 {
143 s16 sin6_family;
144 u16 sin6_port;
145 u32 sin6_flowinfo;
146 u8 sin6_addr[16];
147 u32 sin6_scope_id;
148 };
149
150 SockAddrIn6 serialized_addr{};
151 const auto addr = *reinterpret_cast<sockaddr_in6*>(current->ai_addr);
152 serialized_addr.sin6_family = htons(addr.sin6_family);
153 serialized_addr.sin6_port = htons(addr.sin6_port);
154 serialized_addr.sin6_flowinfo = htonl(addr.sin6_flowinfo);
155 serialized_addr.sin6_scope_id = htonl(addr.sin6_scope_id);
156 std::memcpy(serialized_addr.sin6_addr, &addr.sin6_addr,
157 sizeof(SockAddrIn6::sin6_addr));
158 std::memcpy(header_ptr + header_size, &serialized_addr, sizeof(SockAddrIn6));
159
160 char addr_string_buf[64]{};
161 inet_ntop(AF_INET6, &addr.sin6_addr, addr_string_buf, std::size(addr_string_buf));
162 LOG_INFO(Service, "Resolved host '{}' to IPv6 address {}", host, addr_string_buf);
163 break;
164 }
165 default:
166 std::memcpy(header_ptr + header_size, current->ai_addr, addr_size);
167 break;
168 }
169 }
170 if (current->ai_canonname) {
171 std::memcpy(header_ptr + addr_size, current->ai_canonname, canonname_size);
172 } else { 210 } else {
173 *(header_ptr + header_size + addr_size) = 0; 211 data.push_back(0);
174 } 212 }
175 213
176 current = current->ai_next; 214 LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host,
215 Network::IPv4AddressToString(addrinfo.addr.ip));
177 } 216 }
178 217
179 // 4-byte sentinel value 218 data.resize(data.size() + 4, 0); // 4-byte sentinel value
180 data.push_back(0);
181 data.push_back(0);
182 data.push_back(0);
183 data.push_back(0);
184 219
185 return data; 220 return data;
186} 221}
187 222
188static std::pair<u32, s32> GetAddrInfoRequestImpl(HLERequestContext& ctx) { 223static std::pair<u32, GetAddrInfoError> GetAddrInfoRequestImpl(HLERequestContext& ctx) {
189 struct Parameters { 224 struct Parameters {
190 u8 use_nsd_resolve; 225 u8 use_nsd_resolve;
191 u32 unknown; 226 u32 cancel_handle;
192 u64 process_id; 227 u64 process_id;
193 }; 228 };
194 229
195 IPC::RequestParser rp{ctx}; 230 IPC::RequestParser rp{ctx};
196 const auto parameters = rp.PopRaw<Parameters>(); 231 const auto parameters = rp.PopRaw<Parameters>();
197 232
198 LOG_WARNING(Service, 233 LOG_WARNING(
199 "called with ignored parameters: use_nsd_resolve={}, unknown={}, process_id={}", 234 Service,
200 parameters.use_nsd_resolve, parameters.unknown, parameters.process_id); 235 "called with ignored parameters: use_nsd_resolve={}, cancel_handle={}, process_id={}",
236 parameters.use_nsd_resolve, parameters.cancel_handle, parameters.process_id);
237
238 // TODO: If use_nsd_resolve is true, pass the name through NSD::Resolve
239 // before looking up.
201 240
202 const auto host_buffer = ctx.ReadBuffer(0); 241 const auto host_buffer = ctx.ReadBuffer(0);
203 const std::string host = Common::StringFromBuffer(host_buffer); 242 const std::string host = Common::StringFromBuffer(host_buffer);
204 243
205 const auto service_buffer = ctx.ReadBuffer(1); 244 std::optional<std::string> service = std::nullopt;
206 const std::string service = Common::StringFromBuffer(service_buffer); 245 if (ctx.CanReadBuffer(1)) {
207 246 std::span<const u8> service_buffer = ctx.ReadBuffer(1);
208 addrinfo* addrinfo; 247 service = Common::StringFromBuffer(service_buffer);
209 // Pass null for hints. Serialized hints are also passed in a buffer, but are ignored for now 248 }
210 s32 result_code = getaddrinfo(host.c_str(), service.c_str(), nullptr, &addrinfo);
211 249
212 u32 data_size = 0; 250 // Serialized hints are also passed in a buffer, but are ignored for now.
213 if (result_code == 0 && addrinfo != nullptr) {
214 const std::vector<u8>& data = SerializeAddrInfo(addrinfo, result_code, host);
215 data_size = static_cast<u32>(data.size());
216 freeaddrinfo(addrinfo);
217 251
218 ctx.WriteBuffer(data, 0); 252 auto res = Network::GetAddrInfo(host, service);
253 if (!res.has_value()) {
254 return {0, Translate(res.error())};
219 } 255 }
220 256
221 return std::make_pair(data_size, result_code); 257 std::vector<u8> data = SerializeAddrInfo(res.value(), host);
258 u32 data_size = static_cast<u32>(data.size());
259 ctx.WriteBuffer(data, 0);
260
261 return {data_size, GetAddrInfoError::SUCCESS};
222} 262}
223 263
224void SFDNSRES::GetAddrInfoRequest(HLERequestContext& ctx) { 264void SFDNSRES::GetAddrInfoRequest(HLERequestContext& ctx) {
225 auto [data_size, result_code] = GetAddrInfoRequestImpl(ctx); 265 auto [data_size, emu_gai_err] = GetAddrInfoRequestImpl(ctx);
226 266
227 IPC::ResponseBuilder rb{ctx, 4}; 267 IPC::ResponseBuilder rb{ctx, 5};
228 rb.Push(ResultSuccess); 268 rb.Push(ResultSuccess);
229 rb.Push(static_cast<s32>(AddrInfoErrorToNetDbError(result_code))); // NetDBErrorCode 269 rb.Push(static_cast<s32>(GetAddrInfoErrorToErrno(emu_gai_err))); // errno
230 rb.Push(result_code); // errno 270 rb.Push(static_cast<s32>(emu_gai_err)); // getaddrinfo error code
231 rb.Push(data_size); // serialized size 271 rb.Push(data_size); // serialized size
232} 272}
233 273
234void SFDNSRES::GetAddrInfoRequestWithOptions(HLERequestContext& ctx) { 274void SFDNSRES::GetAddrInfoRequestWithOptions(HLERequestContext& ctx) {
235 // Additional options are ignored 275 // Additional options are ignored
236 auto [data_size, result_code] = GetAddrInfoRequestImpl(ctx); 276 auto [data_size, emu_gai_err] = GetAddrInfoRequestImpl(ctx);
277
278 IPC::ResponseBuilder rb{ctx, 6};
279 rb.Push(ResultSuccess);
280 rb.Push(data_size); // serialized size
281 rb.Push(static_cast<s32>(emu_gai_err)); // getaddrinfo error code
282 rb.Push(static_cast<s32>(GetAddrInfoErrorToNetDbError(emu_gai_err))); // netdb error code
283 rb.Push(static_cast<s32>(GetAddrInfoErrorToErrno(emu_gai_err))); // errno
284}
285
286void SFDNSRES::ResolverSetOptionRequest(HLERequestContext& ctx) {
287 LOG_WARNING(Service, "(STUBBED) called");
288
289 IPC::ResponseBuilder rb{ctx, 3};
237 290
238 IPC::ResponseBuilder rb{ctx, 5};
239 rb.Push(ResultSuccess); 291 rb.Push(ResultSuccess);
240 rb.Push(data_size); // serialized size 292 rb.Push<s32>(0); // bsd errno
241 rb.Push(result_code); // errno
242 rb.Push(static_cast<s32>(AddrInfoErrorToNetDbError(result_code))); // NetDBErrorCode
243 rb.Push(0);
244} 293}
245 294
246} // namespace Service::Sockets 295} // namespace Service::Sockets
diff --git a/src/core/hle/service/sockets/sfdnsres.h b/src/core/hle/service/sockets/sfdnsres.h
index 18e3cd60c..d99a9d560 100644
--- a/src/core/hle/service/sockets/sfdnsres.h
+++ b/src/core/hle/service/sockets/sfdnsres.h
@@ -17,8 +17,11 @@ public:
17 ~SFDNSRES() override; 17 ~SFDNSRES() override;
18 18
19private: 19private:
20 void GetHostByNameRequest(HLERequestContext& ctx);
21 void GetHostByNameRequestWithOptions(HLERequestContext& ctx);
20 void GetAddrInfoRequest(HLERequestContext& ctx); 22 void GetAddrInfoRequest(HLERequestContext& ctx);
21 void GetAddrInfoRequestWithOptions(HLERequestContext& ctx); 23 void GetAddrInfoRequestWithOptions(HLERequestContext& ctx);
24 void ResolverSetOptionRequest(HLERequestContext& ctx);
22}; 25};
23 26
24} // namespace Service::Sockets 27} // namespace Service::Sockets
diff --git a/src/core/hle/service/sockets/sockets.h b/src/core/hle/service/sockets/sockets.h
index acd2dae7b..77426c46e 100644
--- a/src/core/hle/service/sockets/sockets.h
+++ b/src/core/hle/service/sockets/sockets.h
@@ -22,13 +22,35 @@ enum class Errno : u32 {
22 CONNRESET = 104, 22 CONNRESET = 104,
23 NOTCONN = 107, 23 NOTCONN = 107,
24 TIMEDOUT = 110, 24 TIMEDOUT = 110,
25 INPROGRESS = 115,
26};
27
28enum class GetAddrInfoError : s32 {
29 SUCCESS = 0,
30 ADDRFAMILY = 1,
31 AGAIN = 2,
32 BADFLAGS = 3,
33 FAIL = 4,
34 FAMILY = 5,
35 MEMORY = 6,
36 NODATA = 7,
37 NONAME = 8,
38 SERVICE = 9,
39 SOCKTYPE = 10,
40 SYSTEM = 11,
41 BADHINTS = 12,
42 PROTOCOL = 13,
43 OVERFLOW_ = 14, // avoid name collision with Windows macro
44 OTHER = 15,
25}; 45};
26 46
27enum class Domain : u32 { 47enum class Domain : u32 {
48 Unspecified = 0,
28 INET = 2, 49 INET = 2,
29}; 50};
30 51
31enum class Type : u32 { 52enum class Type : u32 {
53 Unspecified = 0,
32 STREAM = 1, 54 STREAM = 1,
33 DGRAM = 2, 55 DGRAM = 2,
34 RAW = 3, 56 RAW = 3,
@@ -36,12 +58,16 @@ enum class Type : u32 {
36}; 58};
37 59
38enum class Protocol : u32 { 60enum class Protocol : u32 {
39 UNSPECIFIED = 0, 61 Unspecified = 0,
40 ICMP = 1, 62 ICMP = 1,
41 TCP = 6, 63 TCP = 6,
42 UDP = 17, 64 UDP = 17,
43}; 65};
44 66
67enum class SocketLevel : u32 {
68 SOCKET = 0xffff, // i.e. SOL_SOCKET
69};
70
45enum class OptName : u32 { 71enum class OptName : u32 {
46 REUSEADDR = 0x4, 72 REUSEADDR = 0x4,
47 KEEPALIVE = 0x8, 73 KEEPALIVE = 0x8,
@@ -51,6 +77,8 @@ enum class OptName : u32 {
51 RCVBUF = 0x1002, 77 RCVBUF = 0x1002,
52 SNDTIMEO = 0x1005, 78 SNDTIMEO = 0x1005,
53 RCVTIMEO = 0x1006, 79 RCVTIMEO = 0x1006,
80 ERROR_ = 0x1007, // avoid name collision with Windows macro
81 NOSIGPIPE = 0x800, // at least according to libnx
54}; 82};
55 83
56enum class ShutdownHow : s32 { 84enum class ShutdownHow : s32 {
@@ -80,6 +108,9 @@ enum class PollEvents : u16 {
80 Err = 1 << 3, 108 Err = 1 << 3,
81 Hup = 1 << 4, 109 Hup = 1 << 4,
82 Nval = 1 << 5, 110 Nval = 1 << 5,
111 RdNorm = 1 << 6,
112 RdBand = 1 << 7,
113 WrBand = 1 << 8,
83}; 114};
84 115
85DECLARE_ENUM_FLAG_OPERATORS(PollEvents); 116DECLARE_ENUM_FLAG_OPERATORS(PollEvents);
diff --git a/src/core/hle/service/sockets/sockets_translate.cpp b/src/core/hle/service/sockets/sockets_translate.cpp
index 594e58f90..2f9a0e39c 100644
--- a/src/core/hle/service/sockets/sockets_translate.cpp
+++ b/src/core/hle/service/sockets/sockets_translate.cpp
@@ -29,6 +29,8 @@ Errno Translate(Network::Errno value) {
29 return Errno::TIMEDOUT; 29 return Errno::TIMEDOUT;
30 case Network::Errno::CONNRESET: 30 case Network::Errno::CONNRESET:
31 return Errno::CONNRESET; 31 return Errno::CONNRESET;
32 case Network::Errno::INPROGRESS:
33 return Errno::INPROGRESS;
32 default: 34 default:
33 UNIMPLEMENTED_MSG("Unimplemented errno={}", value); 35 UNIMPLEMENTED_MSG("Unimplemented errno={}", value);
34 return Errno::SUCCESS; 36 return Errno::SUCCESS;
@@ -39,8 +41,50 @@ std::pair<s32, Errno> Translate(std::pair<s32, Network::Errno> value) {
39 return {value.first, Translate(value.second)}; 41 return {value.first, Translate(value.second)};
40} 42}
41 43
44GetAddrInfoError Translate(Network::GetAddrInfoError error) {
45 switch (error) {
46 case Network::GetAddrInfoError::SUCCESS:
47 return GetAddrInfoError::SUCCESS;
48 case Network::GetAddrInfoError::ADDRFAMILY:
49 return GetAddrInfoError::ADDRFAMILY;
50 case Network::GetAddrInfoError::AGAIN:
51 return GetAddrInfoError::AGAIN;
52 case Network::GetAddrInfoError::BADFLAGS:
53 return GetAddrInfoError::BADFLAGS;
54 case Network::GetAddrInfoError::FAIL:
55 return GetAddrInfoError::FAIL;
56 case Network::GetAddrInfoError::FAMILY:
57 return GetAddrInfoError::FAMILY;
58 case Network::GetAddrInfoError::MEMORY:
59 return GetAddrInfoError::MEMORY;
60 case Network::GetAddrInfoError::NODATA:
61 return GetAddrInfoError::NODATA;
62 case Network::GetAddrInfoError::NONAME:
63 return GetAddrInfoError::NONAME;
64 case Network::GetAddrInfoError::SERVICE:
65 return GetAddrInfoError::SERVICE;
66 case Network::GetAddrInfoError::SOCKTYPE:
67 return GetAddrInfoError::SOCKTYPE;
68 case Network::GetAddrInfoError::SYSTEM:
69 return GetAddrInfoError::SYSTEM;
70 case Network::GetAddrInfoError::BADHINTS:
71 return GetAddrInfoError::BADHINTS;
72 case Network::GetAddrInfoError::PROTOCOL:
73 return GetAddrInfoError::PROTOCOL;
74 case Network::GetAddrInfoError::OVERFLOW_:
75 return GetAddrInfoError::OVERFLOW_;
76 case Network::GetAddrInfoError::OTHER:
77 return GetAddrInfoError::OTHER;
78 default:
79 UNIMPLEMENTED_MSG("Unimplemented GetAddrInfoError={}", error);
80 return GetAddrInfoError::OTHER;
81 }
82}
83
42Network::Domain Translate(Domain domain) { 84Network::Domain Translate(Domain domain) {
43 switch (domain) { 85 switch (domain) {
86 case Domain::Unspecified:
87 return Network::Domain::Unspecified;
44 case Domain::INET: 88 case Domain::INET:
45 return Network::Domain::INET; 89 return Network::Domain::INET;
46 default: 90 default:
@@ -51,6 +95,8 @@ Network::Domain Translate(Domain domain) {
51 95
52Domain Translate(Network::Domain domain) { 96Domain Translate(Network::Domain domain) {
53 switch (domain) { 97 switch (domain) {
98 case Network::Domain::Unspecified:
99 return Domain::Unspecified;
54 case Network::Domain::INET: 100 case Network::Domain::INET:
55 return Domain::INET; 101 return Domain::INET;
56 default: 102 default:
@@ -61,39 +107,69 @@ Domain Translate(Network::Domain domain) {
61 107
62Network::Type Translate(Type type) { 108Network::Type Translate(Type type) {
63 switch (type) { 109 switch (type) {
110 case Type::Unspecified:
111 return Network::Type::Unspecified;
64 case Type::STREAM: 112 case Type::STREAM:
65 return Network::Type::STREAM; 113 return Network::Type::STREAM;
66 case Type::DGRAM: 114 case Type::DGRAM:
67 return Network::Type::DGRAM; 115 return Network::Type::DGRAM;
116 case Type::RAW:
117 return Network::Type::RAW;
118 case Type::SEQPACKET:
119 return Network::Type::SEQPACKET;
68 default: 120 default:
69 UNIMPLEMENTED_MSG("Unimplemented type={}", type); 121 UNIMPLEMENTED_MSG("Unimplemented type={}", type);
70 return Network::Type{}; 122 return Network::Type{};
71 } 123 }
72} 124}
73 125
74Network::Protocol Translate(Type type, Protocol protocol) { 126Type Translate(Network::Type type) {
127 switch (type) {
128 case Network::Type::Unspecified:
129 return Type::Unspecified;
130 case Network::Type::STREAM:
131 return Type::STREAM;
132 case Network::Type::DGRAM:
133 return Type::DGRAM;
134 case Network::Type::RAW:
135 return Type::RAW;
136 case Network::Type::SEQPACKET:
137 return Type::SEQPACKET;
138 default:
139 UNIMPLEMENTED_MSG("Unimplemented type={}", type);
140 return Type{};
141 }
142}
143
144Network::Protocol Translate(Protocol protocol) {
75 switch (protocol) { 145 switch (protocol) {
76 case Protocol::UNSPECIFIED: 146 case Protocol::Unspecified:
77 LOG_WARNING(Service, "Unspecified protocol, assuming protocol from type"); 147 return Network::Protocol::Unspecified;
78 switch (type) {
79 case Type::DGRAM:
80 return Network::Protocol::UDP;
81 case Type::STREAM:
82 return Network::Protocol::TCP;
83 default:
84 return Network::Protocol::TCP;
85 }
86 case Protocol::TCP: 148 case Protocol::TCP:
87 return Network::Protocol::TCP; 149 return Network::Protocol::TCP;
88 case Protocol::UDP: 150 case Protocol::UDP:
89 return Network::Protocol::UDP; 151 return Network::Protocol::UDP;
90 default: 152 default:
91 UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol); 153 UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol);
92 return Network::Protocol::TCP; 154 return Network::Protocol::Unspecified;
155 }
156}
157
158Protocol Translate(Network::Protocol protocol) {
159 switch (protocol) {
160 case Network::Protocol::Unspecified:
161 return Protocol::Unspecified;
162 case Network::Protocol::TCP:
163 return Protocol::TCP;
164 case Network::Protocol::UDP:
165 return Protocol::UDP;
166 default:
167 UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol);
168 return Protocol::Unspecified;
93 } 169 }
94} 170}
95 171
96Network::PollEvents TranslatePollEventsToHost(PollEvents flags) { 172Network::PollEvents Translate(PollEvents flags) {
97 Network::PollEvents result{}; 173 Network::PollEvents result{};
98 const auto translate = [&result, &flags](PollEvents from, Network::PollEvents to) { 174 const auto translate = [&result, &flags](PollEvents from, Network::PollEvents to) {
99 if (True(flags & from)) { 175 if (True(flags & from)) {
@@ -107,12 +183,15 @@ Network::PollEvents TranslatePollEventsToHost(PollEvents flags) {
107 translate(PollEvents::Err, Network::PollEvents::Err); 183 translate(PollEvents::Err, Network::PollEvents::Err);
108 translate(PollEvents::Hup, Network::PollEvents::Hup); 184 translate(PollEvents::Hup, Network::PollEvents::Hup);
109 translate(PollEvents::Nval, Network::PollEvents::Nval); 185 translate(PollEvents::Nval, Network::PollEvents::Nval);
186 translate(PollEvents::RdNorm, Network::PollEvents::RdNorm);
187 translate(PollEvents::RdBand, Network::PollEvents::RdBand);
188 translate(PollEvents::WrBand, Network::PollEvents::WrBand);
110 189
111 UNIMPLEMENTED_IF_MSG((u16)flags != 0, "Unimplemented flags={}", (u16)flags); 190 UNIMPLEMENTED_IF_MSG((u16)flags != 0, "Unimplemented flags={}", (u16)flags);
112 return result; 191 return result;
113} 192}
114 193
115PollEvents TranslatePollEventsToGuest(Network::PollEvents flags) { 194PollEvents Translate(Network::PollEvents flags) {
116 PollEvents result{}; 195 PollEvents result{};
117 const auto translate = [&result, &flags](Network::PollEvents from, PollEvents to) { 196 const auto translate = [&result, &flags](Network::PollEvents from, PollEvents to) {
118 if (True(flags & from)) { 197 if (True(flags & from)) {
@@ -127,13 +206,18 @@ PollEvents TranslatePollEventsToGuest(Network::PollEvents flags) {
127 translate(Network::PollEvents::Err, PollEvents::Err); 206 translate(Network::PollEvents::Err, PollEvents::Err);
128 translate(Network::PollEvents::Hup, PollEvents::Hup); 207 translate(Network::PollEvents::Hup, PollEvents::Hup);
129 translate(Network::PollEvents::Nval, PollEvents::Nval); 208 translate(Network::PollEvents::Nval, PollEvents::Nval);
209 translate(Network::PollEvents::RdNorm, PollEvents::RdNorm);
210 translate(Network::PollEvents::RdBand, PollEvents::RdBand);
211 translate(Network::PollEvents::WrBand, PollEvents::WrBand);
130 212
131 UNIMPLEMENTED_IF_MSG((u16)flags != 0, "Unimplemented flags={}", (u16)flags); 213 UNIMPLEMENTED_IF_MSG((u16)flags != 0, "Unimplemented flags={}", (u16)flags);
132 return result; 214 return result;
133} 215}
134 216
135Network::SockAddrIn Translate(SockAddrIn value) { 217Network::SockAddrIn Translate(SockAddrIn value) {
136 ASSERT(value.len == 0 || value.len == sizeof(value)); 218 // Note: 6 is incorrect, but can be passed by homebrew (because libnx sets
219 // sin_len to 6 when deserializing getaddrinfo results).
220 ASSERT(value.len == 0 || value.len == sizeof(value) || value.len == 6);
137 221
138 return { 222 return {
139 .family = Translate(static_cast<Domain>(value.family)), 223 .family = Translate(static_cast<Domain>(value.family)),
diff --git a/src/core/hle/service/sockets/sockets_translate.h b/src/core/hle/service/sockets/sockets_translate.h
index c93291d3e..694868b37 100644
--- a/src/core/hle/service/sockets/sockets_translate.h
+++ b/src/core/hle/service/sockets/sockets_translate.h
@@ -17,6 +17,9 @@ Errno Translate(Network::Errno value);
17/// Translate abstract return value errno pair to guest return value errno pair 17/// Translate abstract return value errno pair to guest return value errno pair
18std::pair<s32, Errno> Translate(std::pair<s32, Network::Errno> value); 18std::pair<s32, Errno> Translate(std::pair<s32, Network::Errno> value);
19 19
20/// Translate abstract getaddrinfo error to guest getaddrinfo error
21GetAddrInfoError Translate(Network::GetAddrInfoError value);
22
20/// Translate guest domain to abstract domain 23/// Translate guest domain to abstract domain
21Network::Domain Translate(Domain domain); 24Network::Domain Translate(Domain domain);
22 25
@@ -26,14 +29,20 @@ Domain Translate(Network::Domain domain);
26/// Translate guest type to abstract type 29/// Translate guest type to abstract type
27Network::Type Translate(Type type); 30Network::Type Translate(Type type);
28 31
32/// Translate abstract type to guest type
33Type Translate(Network::Type type);
34
29/// Translate guest protocol to abstract protocol 35/// Translate guest protocol to abstract protocol
30Network::Protocol Translate(Type type, Protocol protocol); 36Network::Protocol Translate(Protocol protocol);
31 37
32/// Translate abstract poll event flags to guest poll event flags 38/// Translate abstract protocol to guest protocol
33Network::PollEvents TranslatePollEventsToHost(PollEvents flags); 39Protocol Translate(Network::Protocol protocol);
34 40
35/// Translate guest poll event flags to abstract poll event flags 41/// Translate guest poll event flags to abstract poll event flags
36PollEvents TranslatePollEventsToGuest(Network::PollEvents flags); 42Network::PollEvents Translate(PollEvents flags);
43
44/// Translate abstract poll event flags to guest poll event flags
45PollEvents Translate(Network::PollEvents flags);
37 46
38/// Translate guest socket address structure to abstract socket address structure 47/// Translate guest socket address structure to abstract socket address structure
39Network::SockAddrIn Translate(SockAddrIn value); 48Network::SockAddrIn Translate(SockAddrIn value);
diff --git a/src/core/hle/service/ssl/ssl.cpp b/src/core/hle/service/ssl/ssl.cpp
index 2b99dd7ac..a3b54c7f0 100644
--- a/src/core/hle/service/ssl/ssl.cpp
+++ b/src/core/hle/service/ssl/ssl.cpp
@@ -1,10 +1,18 @@
1// SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project 1// SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later 2// SPDX-License-Identifier: GPL-2.0-or-later
3 3
4#include "common/string_util.h"
5
6#include "core/core.h"
4#include "core/hle/service/ipc_helpers.h" 7#include "core/hle/service/ipc_helpers.h"
5#include "core/hle/service/server_manager.h" 8#include "core/hle/service/server_manager.h"
6#include "core/hle/service/service.h" 9#include "core/hle/service/service.h"
10#include "core/hle/service/sm/sm.h"
11#include "core/hle/service/sockets/bsd.h"
7#include "core/hle/service/ssl/ssl.h" 12#include "core/hle/service/ssl/ssl.h"
13#include "core/hle/service/ssl/ssl_backend.h"
14#include "core/internal_network/network.h"
15#include "core/internal_network/sockets.h"
8 16
9namespace Service::SSL { 17namespace Service::SSL {
10 18
@@ -20,6 +28,18 @@ enum class ContextOption : u32 {
20 CrlImportDateCheckEnable = 1, 28 CrlImportDateCheckEnable = 1,
21}; 29};
22 30
31// This is nn::ssl::Connection::IoMode
32enum class IoMode : u32 {
33 Blocking = 1,
34 NonBlocking = 2,
35};
36
37// This is nn::ssl::sf::OptionType
38enum class OptionType : u32 {
39 DoNotCloseSocket = 0,
40 GetServerCertChain = 1,
41};
42
23// This is nn::ssl::sf::SslVersion 43// This is nn::ssl::sf::SslVersion
24struct SslVersion { 44struct SslVersion {
25 union { 45 union {
@@ -34,35 +54,42 @@ struct SslVersion {
34 }; 54 };
35}; 55};
36 56
57struct SslContextSharedData {
58 u32 connection_count = 0;
59};
60
37class ISslConnection final : public ServiceFramework<ISslConnection> { 61class ISslConnection final : public ServiceFramework<ISslConnection> {
38public: 62public:
39 explicit ISslConnection(Core::System& system_, SslVersion version) 63 explicit ISslConnection(Core::System& system_, SslVersion version,
40 : ServiceFramework{system_, "ISslConnection"}, ssl_version{version} { 64 std::shared_ptr<SslContextSharedData>& shared_data,
65 std::unique_ptr<SSLConnectionBackend>&& backend)
66 : ServiceFramework{system_, "ISslConnection"}, ssl_version{version},
67 shared_data_{shared_data}, backend_{std::move(backend)} {
41 // clang-format off 68 // clang-format off
42 static const FunctionInfo functions[] = { 69 static const FunctionInfo functions[] = {
43 {0, nullptr, "SetSocketDescriptor"}, 70 {0, &ISslConnection::SetSocketDescriptor, "SetSocketDescriptor"},
44 {1, nullptr, "SetHostName"}, 71 {1, &ISslConnection::SetHostName, "SetHostName"},
45 {2, nullptr, "SetVerifyOption"}, 72 {2, &ISslConnection::SetVerifyOption, "SetVerifyOption"},
46 {3, nullptr, "SetIoMode"}, 73 {3, &ISslConnection::SetIoMode, "SetIoMode"},
47 {4, nullptr, "GetSocketDescriptor"}, 74 {4, nullptr, "GetSocketDescriptor"},
48 {5, nullptr, "GetHostName"}, 75 {5, nullptr, "GetHostName"},
49 {6, nullptr, "GetVerifyOption"}, 76 {6, nullptr, "GetVerifyOption"},
50 {7, nullptr, "GetIoMode"}, 77 {7, nullptr, "GetIoMode"},
51 {8, nullptr, "DoHandshake"}, 78 {8, &ISslConnection::DoHandshake, "DoHandshake"},
52 {9, nullptr, "DoHandshakeGetServerCert"}, 79 {9, &ISslConnection::DoHandshakeGetServerCert, "DoHandshakeGetServerCert"},
53 {10, nullptr, "Read"}, 80 {10, &ISslConnection::Read, "Read"},
54 {11, nullptr, "Write"}, 81 {11, &ISslConnection::Write, "Write"},
55 {12, nullptr, "Pending"}, 82 {12, &ISslConnection::Pending, "Pending"},
56 {13, nullptr, "Peek"}, 83 {13, nullptr, "Peek"},
57 {14, nullptr, "Poll"}, 84 {14, nullptr, "Poll"},
58 {15, nullptr, "GetVerifyCertError"}, 85 {15, nullptr, "GetVerifyCertError"},
59 {16, nullptr, "GetNeededServerCertBufferSize"}, 86 {16, nullptr, "GetNeededServerCertBufferSize"},
60 {17, nullptr, "SetSessionCacheMode"}, 87 {17, &ISslConnection::SetSessionCacheMode, "SetSessionCacheMode"},
61 {18, nullptr, "GetSessionCacheMode"}, 88 {18, nullptr, "GetSessionCacheMode"},
62 {19, nullptr, "FlushSessionCache"}, 89 {19, nullptr, "FlushSessionCache"},
63 {20, nullptr, "SetRenegotiationMode"}, 90 {20, nullptr, "SetRenegotiationMode"},
64 {21, nullptr, "GetRenegotiationMode"}, 91 {21, nullptr, "GetRenegotiationMode"},
65 {22, nullptr, "SetOption"}, 92 {22, &ISslConnection::SetOption, "SetOption"},
66 {23, nullptr, "GetOption"}, 93 {23, nullptr, "GetOption"},
67 {24, nullptr, "GetVerifyCertErrors"}, 94 {24, nullptr, "GetVerifyCertErrors"},
68 {25, nullptr, "GetCipherInfo"}, 95 {25, nullptr, "GetCipherInfo"},
@@ -80,21 +107,295 @@ public:
80 // clang-format on 107 // clang-format on
81 108
82 RegisterHandlers(functions); 109 RegisterHandlers(functions);
110
111 shared_data->connection_count++;
112 }
113
114 ~ISslConnection() {
115 shared_data_->connection_count--;
116 if (fd_to_close_.has_value()) {
117 s32 fd = *fd_to_close_;
118 if (!do_not_close_socket_) {
119 LOG_ERROR(Service_SSL,
120 "do_not_close_socket was changed after setting socket; is this right?");
121 } else {
122 auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u");
123 if (bsd) {
124 auto err = bsd->CloseImpl(fd);
125 if (err != Service::Sockets::Errno::SUCCESS) {
126 LOG_ERROR(Service_SSL, "failed to close duplicated socket: {}", err);
127 }
128 }
129 }
130 }
83 } 131 }
84 132
85private: 133private:
86 SslVersion ssl_version; 134 SslVersion ssl_version;
135 std::shared_ptr<SslContextSharedData> shared_data_;
136 std::unique_ptr<SSLConnectionBackend> backend_;
137 std::optional<int> fd_to_close_;
138 bool do_not_close_socket_ = false;
139 bool get_server_cert_chain_ = false;
140 std::shared_ptr<Network::SocketBase> socket_;
141 bool did_set_host_name_ = false;
142 bool did_handshake_ = false;
143
144 ResultVal<s32> SetSocketDescriptorImpl(s32 fd) {
145 LOG_DEBUG(Service_SSL, "called, fd={}", fd);
146 ASSERT(!did_handshake_);
147 auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u");
148 ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; });
149 s32 ret_fd;
150 // Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor
151 if (do_not_close_socket_) {
152 auto res = bsd->DuplicateSocketImpl(fd);
153 if (!res.has_value()) {
154 LOG_ERROR(Service_SSL, "failed to duplicate socket");
155 return ResultInvalidSocket;
156 }
157 fd = *res;
158 fd_to_close_ = fd;
159 ret_fd = fd;
160 } else {
161 ret_fd = -1;
162 }
163 std::optional<std::shared_ptr<Network::SocketBase>> sock = bsd->GetSocket(fd);
164 if (!sock.has_value()) {
165 LOG_ERROR(Service_SSL, "invalid socket fd {}", fd);
166 return ResultInvalidSocket;
167 }
168 socket_ = std::move(*sock);
169 backend_->SetSocket(socket_);
170 return ret_fd;
171 }
172
173 Result SetHostNameImpl(const std::string& hostname) {
174 LOG_DEBUG(Service_SSL, "SetHostNameImpl({})", hostname);
175 ASSERT(!did_handshake_);
176 Result res = backend_->SetHostName(hostname);
177 if (res == ResultSuccess) {
178 did_set_host_name_ = true;
179 }
180 return res;
181 }
182
183 Result SetVerifyOptionImpl(u32 option) {
184 ASSERT(!did_handshake_);
185 LOG_WARNING(Service_SSL, "(STUBBED) called. option={}", option);
186 return ResultSuccess;
187 }
188
189 Result SetIOModeImpl(u32 _mode) {
190 auto mode = static_cast<IoMode>(_mode);
191 ASSERT(mode == IoMode::Blocking || mode == IoMode::NonBlocking);
192 ASSERT_OR_EXECUTE(socket_, { return ResultNoSocket; });
193
194 bool non_block = mode == IoMode::NonBlocking;
195 Network::Errno e = socket_->SetNonBlock(non_block);
196 if (e != Network::Errno::SUCCESS) {
197 LOG_ERROR(Service_SSL, "Failed to set native socket non-block flag to {}", non_block);
198 }
199 return ResultSuccess;
200 }
201
202 Result SetSessionCacheModeImpl(u32 mode) {
203 ASSERT(!did_handshake_);
204 LOG_WARNING(Service_SSL, "(STUBBED) called. value={}", mode);
205 return ResultSuccess;
206 }
207
208 Result DoHandshakeImpl() {
209 ASSERT_OR_EXECUTE(!did_handshake_ && socket_, { return ResultNoSocket; });
210 ASSERT_OR_EXECUTE_MSG(
211 did_set_host_name_, { return ResultInternalError; },
212 "Expected SetHostName before DoHandshake");
213 Result res = backend_->DoHandshake();
214 did_handshake_ = res.IsSuccess();
215 return res;
216 }
217
218 std::vector<u8> SerializeServerCerts(const std::vector<std::vector<u8>>& certs) {
219 struct Header {
220 u64 magic;
221 u32 count;
222 u32 pad;
223 };
224 struct EntryHeader {
225 u32 size;
226 u32 offset;
227 };
228 if (!get_server_cert_chain_) {
229 // Just return the first one, unencoded.
230 ASSERT_OR_EXECUTE_MSG(
231 !certs.empty(), { return {}; }, "Should be at least one server cert");
232 return certs[0];
233 }
234 std::vector<u8> ret;
235 Header header{0x4E4D684374726543, static_cast<u32>(certs.size()), 0};
236 ret.insert(ret.end(), reinterpret_cast<u8*>(&header), reinterpret_cast<u8*>(&header + 1));
237 size_t data_offset = sizeof(Header) + certs.size() * sizeof(EntryHeader);
238 for (auto& cert : certs) {
239 EntryHeader entry_header{static_cast<u32>(cert.size()), static_cast<u32>(data_offset)};
240 data_offset += cert.size();
241 ret.insert(ret.end(), reinterpret_cast<u8*>(&entry_header),
242 reinterpret_cast<u8*>(&entry_header + 1));
243 }
244 for (auto& cert : certs) {
245 ret.insert(ret.end(), cert.begin(), cert.end());
246 }
247 return ret;
248 }
249
250 ResultVal<std::vector<u8>> ReadImpl(size_t size) {
251 ASSERT_OR_EXECUTE(did_handshake_, { return ResultInternalError; });
252 std::vector<u8> res(size);
253 ResultVal<size_t> actual = backend_->Read(res);
254 if (actual.Failed()) {
255 return actual.Code();
256 }
257 res.resize(*actual);
258 return res;
259 }
260
261 ResultVal<size_t> WriteImpl(std::span<const u8> data) {
262 ASSERT_OR_EXECUTE(did_handshake_, { return ResultInternalError; });
263 return backend_->Write(data);
264 }
265
266 ResultVal<s32> PendingImpl() {
267 LOG_WARNING(Service_SSL, "(STUBBED) called.");
268 return 0;
269 }
270
271 void SetSocketDescriptor(HLERequestContext& ctx) {
272 IPC::RequestParser rp{ctx};
273 const s32 fd = rp.Pop<s32>();
274 const ResultVal<s32> res = SetSocketDescriptorImpl(fd);
275 IPC::ResponseBuilder rb{ctx, 3};
276 rb.Push(res.Code());
277 rb.Push<s32>(res.ValueOr(-1));
278 }
279
280 void SetHostName(HLERequestContext& ctx) {
281 const std::string hostname = Common::StringFromBuffer(ctx.ReadBuffer());
282 const Result res = SetHostNameImpl(hostname);
283 IPC::ResponseBuilder rb{ctx, 2};
284 rb.Push(res);
285 }
286
287 void SetVerifyOption(HLERequestContext& ctx) {
288 IPC::RequestParser rp{ctx};
289 const u32 option = rp.Pop<u32>();
290 const Result res = SetVerifyOptionImpl(option);
291 IPC::ResponseBuilder rb{ctx, 2};
292 rb.Push(res);
293 }
294
295 void SetIoMode(HLERequestContext& ctx) {
296 IPC::RequestParser rp{ctx};
297 const u32 mode = rp.Pop<u32>();
298 const Result res = SetIOModeImpl(mode);
299 IPC::ResponseBuilder rb{ctx, 2};
300 rb.Push(res);
301 }
302
303 void DoHandshake(HLERequestContext& ctx) {
304 const Result res = DoHandshakeImpl();
305 IPC::ResponseBuilder rb{ctx, 2};
306 rb.Push(res);
307 }
308
309 void DoHandshakeGetServerCert(HLERequestContext& ctx) {
310 Result res = DoHandshakeImpl();
311 u32 certs_count = 0;
312 u32 certs_size = 0;
313 if (res == ResultSuccess) {
314 auto certs = backend_->GetServerCerts();
315 if (certs.Succeeded()) {
316 std::vector<u8> certs_buf = SerializeServerCerts(*certs);
317 ctx.WriteBuffer(certs_buf);
318 certs_count = static_cast<u32>(certs->size());
319 certs_size = static_cast<u32>(certs_buf.size());
320 }
321 }
322 IPC::ResponseBuilder rb{ctx, 4};
323 rb.Push(res);
324 rb.Push(certs_size);
325 rb.Push(certs_count);
326 }
327
328 void Read(HLERequestContext& ctx) {
329 const ResultVal<std::vector<u8>> res = ReadImpl(ctx.GetWriteBufferSize());
330 IPC::ResponseBuilder rb{ctx, 3};
331 rb.Push(res.Code());
332 if (res.Succeeded()) {
333 rb.Push(static_cast<u32>(res->size()));
334 ctx.WriteBuffer(*res);
335 } else {
336 rb.Push(static_cast<u32>(0));
337 }
338 }
339
340 void Write(HLERequestContext& ctx) {
341 const ResultVal<size_t> res = WriteImpl(ctx.ReadBuffer());
342 IPC::ResponseBuilder rb{ctx, 3};
343 rb.Push(res.Code());
344 rb.Push(static_cast<u32>(res.ValueOr(0)));
345 }
346
347 void Pending(HLERequestContext& ctx) {
348 const ResultVal<s32> res = PendingImpl();
349 IPC::ResponseBuilder rb{ctx, 3};
350 rb.Push(res.Code());
351 rb.Push<s32>(res.ValueOr(0));
352 }
353
354 void SetSessionCacheMode(HLERequestContext& ctx) {
355 IPC::RequestParser rp{ctx};
356 const u32 mode = rp.Pop<u32>();
357 const Result res = SetSessionCacheModeImpl(mode);
358 IPC::ResponseBuilder rb{ctx, 2};
359 rb.Push(res);
360 }
361
362 void SetOption(HLERequestContext& ctx) {
363 struct Parameters {
364 OptionType option;
365 s32 value;
366 };
367 static_assert(sizeof(Parameters) == 0x8, "Parameters is an invalid size");
368
369 IPC::RequestParser rp{ctx};
370 const auto parameters = rp.PopRaw<Parameters>();
371
372 switch (parameters.option) {
373 case OptionType::DoNotCloseSocket:
374 do_not_close_socket_ = static_cast<bool>(parameters.value);
375 break;
376 case OptionType::GetServerCertChain:
377 get_server_cert_chain_ = static_cast<bool>(parameters.value);
378 break;
379 default:
380 LOG_WARNING(Service_SSL, "unrecognized option={}, value={}", parameters.option,
381 parameters.value);
382 }
383
384 IPC::ResponseBuilder rb{ctx, 2};
385 rb.Push(ResultSuccess);
386 }
87}; 387};
88 388
89class ISslContext final : public ServiceFramework<ISslContext> { 389class ISslContext final : public ServiceFramework<ISslContext> {
90public: 390public:
91 explicit ISslContext(Core::System& system_, SslVersion version) 391 explicit ISslContext(Core::System& system_, SslVersion version)
92 : ServiceFramework{system_, "ISslContext"}, ssl_version{version} { 392 : ServiceFramework{system_, "ISslContext"}, ssl_version{version},
393 shared_data_{std::make_shared<SslContextSharedData>()} {
93 static const FunctionInfo functions[] = { 394 static const FunctionInfo functions[] = {
94 {0, &ISslContext::SetOption, "SetOption"}, 395 {0, &ISslContext::SetOption, "SetOption"},
95 {1, nullptr, "GetOption"}, 396 {1, nullptr, "GetOption"},
96 {2, &ISslContext::CreateConnection, "CreateConnection"}, 397 {2, &ISslContext::CreateConnection, "CreateConnection"},
97 {3, nullptr, "GetConnectionCount"}, 398 {3, &ISslContext::GetConnectionCount, "GetConnectionCount"},
98 {4, &ISslContext::ImportServerPki, "ImportServerPki"}, 399 {4, &ISslContext::ImportServerPki, "ImportServerPki"},
99 {5, &ISslContext::ImportClientPki, "ImportClientPki"}, 400 {5, &ISslContext::ImportClientPki, "ImportClientPki"},
100 {6, nullptr, "RemoveServerPki"}, 401 {6, nullptr, "RemoveServerPki"},
@@ -111,6 +412,7 @@ public:
111 412
112private: 413private:
113 SslVersion ssl_version; 414 SslVersion ssl_version;
415 std::shared_ptr<SslContextSharedData> shared_data_;
114 416
115 void SetOption(HLERequestContext& ctx) { 417 void SetOption(HLERequestContext& ctx) {
116 struct Parameters { 418 struct Parameters {
@@ -130,11 +432,24 @@ private:
130 } 432 }
131 433
132 void CreateConnection(HLERequestContext& ctx) { 434 void CreateConnection(HLERequestContext& ctx) {
133 LOG_WARNING(Service_SSL, "(STUBBED) called"); 435 LOG_WARNING(Service_SSL, "called");
436
437 auto backend_res = CreateSSLConnectionBackend();
134 438
135 IPC::ResponseBuilder rb{ctx, 2, 0, 1}; 439 IPC::ResponseBuilder rb{ctx, 2, 0, 1};
440 rb.Push(backend_res.Code());
441 if (backend_res.Succeeded()) {
442 rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data_,
443 std::move(*backend_res));
444 }
445 }
446
447 void GetConnectionCount(HLERequestContext& ctx) {
448 LOG_WARNING(Service_SSL, "connection_count={}", shared_data_->connection_count);
449
450 IPC::ResponseBuilder rb{ctx, 3};
136 rb.Push(ResultSuccess); 451 rb.Push(ResultSuccess);
137 rb.PushIpcInterface<ISslConnection>(system, ssl_version); 452 rb.Push(shared_data_->connection_count);
138 } 453 }
139 454
140 void ImportServerPki(HLERequestContext& ctx) { 455 void ImportServerPki(HLERequestContext& ctx) {
diff --git a/src/core/hle/service/ssl/ssl_backend.h b/src/core/hle/service/ssl/ssl_backend.h
new file mode 100644
index 000000000..624e07d41
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend.h
@@ -0,0 +1,44 @@
1// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later
3
4#pragma once
5
6#include "core/hle/result.h"
7
8#include "common/common_types.h"
9
10#include <memory>
11#include <span>
12#include <string>
13#include <vector>
14
15namespace Network {
16class SocketBase;
17}
18
19namespace Service::SSL {
20
21constexpr Result ResultNoSocket{ErrorModule::SSLSrv, 103};
22constexpr Result ResultInvalidSocket{ErrorModule::SSLSrv, 106};
23constexpr Result ResultTimeout{ErrorModule::SSLSrv, 205};
24constexpr Result ResultInternalError{ErrorModule::SSLSrv, 999}; // made up
25
26constexpr Result ResultWouldBlock{ErrorModule::SSLSrv, 204};
27// ^ ResultWouldBlock is returned from Read and Write, and oddly, DoHandshake,
28// with no way in the latter case to distinguish whether the client should poll
29// for read or write. The one official client I've seen handles this by always
30// polling for read (with a timeout).
31
32class SSLConnectionBackend {
33public:
34 virtual void SetSocket(std::shared_ptr<Network::SocketBase> socket) = 0;
35 virtual Result SetHostName(const std::string& hostname) = 0;
36 virtual Result DoHandshake() = 0;
37 virtual ResultVal<size_t> Read(std::span<u8> data) = 0;
38 virtual ResultVal<size_t> Write(std::span<const u8> data) = 0;
39 virtual ResultVal<std::vector<std::vector<u8>>> GetServerCerts() = 0;
40};
41
42ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend();
43
44} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_none.cpp b/src/core/hle/service/ssl/ssl_backend_none.cpp
new file mode 100644
index 000000000..eb01561e2
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_none.cpp
@@ -0,0 +1,15 @@
1// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later
3
4#include "core/hle/service/ssl/ssl_backend.h"
5
6#include "common/logging/log.h"
7
8namespace Service::SSL {
9
10ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
11 LOG_ERROR(Service_SSL, "No SSL backend on this platform");
12 return ResultInternalError;
13}
14
15} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_openssl.cpp b/src/core/hle/service/ssl/ssl_backend_openssl.cpp
new file mode 100644
index 000000000..cf9b904ac
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_openssl.cpp
@@ -0,0 +1,342 @@
1// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later
3
4#include "core/hle/service/ssl/ssl_backend.h"
5#include "core/internal_network/network.h"
6#include "core/internal_network/sockets.h"
7
8#include "common/fs/file.h"
9#include "common/hex_util.h"
10#include "common/string_util.h"
11
12#include <mutex>
13
14#include <openssl/bio.h>
15#include <openssl/err.h>
16#include <openssl/ssl.h>
17#include <openssl/x509.h>
18
19using namespace Common::FS;
20
21namespace Service::SSL {
22
23// Import OpenSSL's `SSL` type into the namespace. This is needed because the
24// namespace is also named `SSL`.
25using ::SSL;
26
27namespace {
28
29std::once_flag one_time_init_flag;
30bool one_time_init_success = false;
31
32SSL_CTX* ssl_ctx;
33IOFile key_log_file; // only open if SSLKEYLOGFILE set in environment
34BIO_METHOD* bio_meth;
35
36Result CheckOpenSSLErrors();
37void OneTimeInit();
38void OneTimeInitLogFile();
39bool OneTimeInitBIO();
40
41} // namespace
42
43class SSLConnectionBackendOpenSSL final : public SSLConnectionBackend {
44public:
45 Result Init() {
46 std::call_once(one_time_init_flag, OneTimeInit);
47
48 if (!one_time_init_success) {
49 LOG_ERROR(Service_SSL,
50 "Can't create SSL connection because OpenSSL one-time initialization failed");
51 return ResultInternalError;
52 }
53
54 ssl_ = SSL_new(ssl_ctx);
55 if (!ssl_) {
56 LOG_ERROR(Service_SSL, "SSL_new failed");
57 return CheckOpenSSLErrors();
58 }
59
60 SSL_set_connect_state(ssl_);
61
62 bio_ = BIO_new(bio_meth);
63 if (!bio_) {
64 LOG_ERROR(Service_SSL, "BIO_new failed");
65 return CheckOpenSSLErrors();
66 }
67
68 BIO_set_data(bio_, this);
69 BIO_set_init(bio_, 1);
70 SSL_set_bio(ssl_, bio_, bio_);
71
72 return ResultSuccess;
73 }
74
75 void SetSocket(std::shared_ptr<Network::SocketBase> socket) override {
76 socket_ = socket;
77 }
78
79 Result SetHostName(const std::string& hostname) override {
80 if (!SSL_set1_host(ssl_, hostname.c_str())) { // hostname for verification
81 LOG_ERROR(Service_SSL, "SSL_set1_host({}) failed", hostname);
82 return CheckOpenSSLErrors();
83 }
84 if (!SSL_set_tlsext_host_name(ssl_, hostname.c_str())) { // hostname for SNI
85 LOG_ERROR(Service_SSL, "SSL_set_tlsext_host_name({}) failed", hostname);
86 return CheckOpenSSLErrors();
87 }
88 return ResultSuccess;
89 }
90
91 Result DoHandshake() override {
92 SSL_set_verify_result(ssl_, X509_V_OK);
93 int ret = SSL_do_handshake(ssl_);
94 long verify_result = SSL_get_verify_result(ssl_);
95 if (verify_result != X509_V_OK) {
96 LOG_ERROR(Service_SSL, "SSL cert verification failed because: {}",
97 X509_verify_cert_error_string(verify_result));
98 return CheckOpenSSLErrors();
99 }
100 if (ret <= 0) {
101 int ssl_err = SSL_get_error(ssl_, ret);
102 if (ssl_err == SSL_ERROR_ZERO_RETURN ||
103 (ssl_err == SSL_ERROR_SYSCALL && got_read_eof_)) {
104 LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
105 return ResultInternalError;
106 }
107 }
108 return HandleReturn("SSL_do_handshake", 0, ret).Code();
109 }
110
111 ResultVal<size_t> Read(std::span<u8> data) override {
112 size_t actual;
113 int ret = SSL_read_ex(ssl_, data.data(), data.size(), &actual);
114 return HandleReturn("SSL_read_ex", actual, ret);
115 }
116
117 ResultVal<size_t> Write(std::span<const u8> data) override {
118 size_t actual;
119 int ret = SSL_write_ex(ssl_, data.data(), data.size(), &actual);
120 return HandleReturn("SSL_write_ex", actual, ret);
121 }
122
123 ResultVal<size_t> HandleReturn(const char* what, size_t actual, int ret) {
124 int ssl_err = SSL_get_error(ssl_, ret);
125 CheckOpenSSLErrors();
126 switch (ssl_err) {
127 case SSL_ERROR_NONE:
128 return actual;
129 case SSL_ERROR_ZERO_RETURN:
130 LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_ZERO_RETURN", what);
131 // DoHandshake special-cases this, but for Read and Write:
132 return size_t(0);
133 case SSL_ERROR_WANT_READ:
134 LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_READ", what);
135 return ResultWouldBlock;
136 case SSL_ERROR_WANT_WRITE:
137 LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_WRITE", what);
138 return ResultWouldBlock;
139 default:
140 if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof_) {
141 LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what);
142 return size_t(0);
143 }
144 LOG_ERROR(Service_SSL, "{} => other SSL_get_error return value {}", what, ssl_err);
145 return ResultInternalError;
146 }
147 }
148
149 ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
150 STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl_);
151 if (!chain) {
152 LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr");
153 return ResultInternalError;
154 }
155 std::vector<std::vector<u8>> ret;
156 int count = sk_X509_num(chain);
157 ASSERT(count >= 0);
158 for (int i = 0; i < count; i++) {
159 X509* x509 = sk_X509_value(chain, i);
160 ASSERT_OR_EXECUTE(x509 != nullptr, { continue; });
161 unsigned char* buf = nullptr;
162 int len = i2d_X509(x509, &buf);
163 ASSERT_OR_EXECUTE(len >= 0 && buf, { continue; });
164 ret.emplace_back(buf, buf + len);
165 OPENSSL_free(buf);
166 }
167 return ret;
168 }
169
170 ~SSLConnectionBackendOpenSSL() {
171 // these are null-tolerant:
172 SSL_free(ssl_);
173 BIO_free(bio_);
174 }
175
176 static void KeyLogCallback(const SSL* ssl, const char* line) {
177 std::string str(line);
178 str.push_back('\n');
179 // Do this in a single WriteString for atomicity if multiple instances
180 // are running on different threads (though that can't currently
181 // happen).
182 if (key_log_file.WriteString(str) != str.size() || !key_log_file.Flush()) {
183 LOG_CRITICAL(Service_SSL, "Failed to write to SSLKEYLOGFILE");
184 }
185 LOG_DEBUG(Service_SSL, "Wrote to SSLKEYLOGFILE: {}", line);
186 }
187
188 static int WriteCallback(BIO* bio, const char* buf, size_t len, size_t* actual_p) {
189 auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio));
190 ASSERT_OR_EXECUTE_MSG(
191 self->socket_, { return 0; }, "OpenSSL asked to send but we have no socket");
192 BIO_clear_retry_flags(bio);
193 auto [actual, err] = self->socket_->Send({reinterpret_cast<const u8*>(buf), len}, 0);
194 switch (err) {
195 case Network::Errno::SUCCESS:
196 *actual_p = actual;
197 return 1;
198 case Network::Errno::AGAIN:
199 BIO_set_flags(bio, BIO_FLAGS_WRITE | BIO_FLAGS_SHOULD_RETRY);
200 return 0;
201 default:
202 LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err);
203 return -1;
204 }
205 }
206
207 static int ReadCallback(BIO* bio, char* buf, size_t len, size_t* actual_p) {
208 auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio));
209 ASSERT_OR_EXECUTE_MSG(
210 self->socket_, { return 0; }, "OpenSSL asked to recv but we have no socket");
211 BIO_clear_retry_flags(bio);
212 auto [actual, err] = self->socket_->Recv(0, {reinterpret_cast<u8*>(buf), len});
213 switch (err) {
214 case Network::Errno::SUCCESS:
215 *actual_p = actual;
216 if (actual == 0) {
217 self->got_read_eof_ = true;
218 }
219 return actual ? 1 : 0;
220 case Network::Errno::AGAIN:
221 BIO_set_flags(bio, BIO_FLAGS_READ | BIO_FLAGS_SHOULD_RETRY);
222 return 0;
223 default:
224 LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err);
225 return -1;
226 }
227 }
228
229 static long CtrlCallback(BIO* bio, int cmd, long larg, void* parg) {
230 switch (cmd) {
231 case BIO_CTRL_FLUSH:
232 // Nothing to flush.
233 return 1;
234 case BIO_CTRL_PUSH:
235 case BIO_CTRL_POP:
236 case BIO_CTRL_GET_KTLS_SEND:
237 case BIO_CTRL_GET_KTLS_RECV:
238 // We don't support these operations, but don't bother logging them
239 // as they're nothing unusual.
240 return 0;
241 default:
242 LOG_DEBUG(Service_SSL, "OpenSSL BIO got ctrl({}, {}, {})", cmd, larg, parg);
243 return 0;
244 }
245 }
246
247 SSL* ssl_ = nullptr;
248 BIO* bio_ = nullptr;
249 bool got_read_eof_ = false;
250
251 std::shared_ptr<Network::SocketBase> socket_;
252};
253
254ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
255 auto conn = std::make_unique<SSLConnectionBackendOpenSSL>();
256 Result res = conn->Init();
257 if (res.IsFailure()) {
258 return res;
259 }
260 return conn;
261}
262
263namespace {
264
265Result CheckOpenSSLErrors() {
266 unsigned long rc;
267 const char* file;
268 int line;
269 const char* func;
270 const char* data;
271 int flags;
272 while ((rc = ERR_get_error_all(&file, &line, &func, &data, &flags))) {
273 std::string msg;
274 msg.resize(1024, '\0');
275 ERR_error_string_n(rc, msg.data(), msg.size());
276 msg.resize(strlen(msg.data()), '\0');
277 if (flags & ERR_TXT_STRING) {
278 msg.append(" | ");
279 msg.append(data);
280 }
281 Common::Log::FmtLogMessage(Common::Log::Class::Service_SSL, Common::Log::Level::Error,
282 Common::Log::TrimSourcePath(file), line, func, "OpenSSL: {}",
283 msg);
284 }
285 return ResultInternalError;
286}
287
288void OneTimeInit() {
289 ssl_ctx = SSL_CTX_new(TLS_client_method());
290 if (!ssl_ctx) {
291 LOG_ERROR(Service_SSL, "SSL_CTX_new failed");
292 CheckOpenSSLErrors();
293 return;
294 }
295
296 SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, nullptr);
297
298 if (!SSL_CTX_set_default_verify_paths(ssl_ctx)) {
299 LOG_ERROR(Service_SSL, "SSL_CTX_set_default_verify_paths failed");
300 CheckOpenSSLErrors();
301 return;
302 }
303
304 OneTimeInitLogFile();
305
306 if (!OneTimeInitBIO()) {
307 return;
308 }
309
310 one_time_init_success = true;
311}
312
313void OneTimeInitLogFile() {
314 const char* logfile = getenv("SSLKEYLOGFILE");
315 if (logfile) {
316 key_log_file.Open(logfile, FileAccessMode::Append, FileType::TextFile,
317 FileShareFlag::ShareWriteOnly);
318 if (key_log_file.IsOpen()) {
319 SSL_CTX_set_keylog_callback(ssl_ctx, &SSLConnectionBackendOpenSSL::KeyLogCallback);
320 } else {
321 LOG_CRITICAL(Service_SSL,
322 "SSLKEYLOGFILE was set but file could not be opened; not logging keys!");
323 }
324 }
325}
326
327bool OneTimeInitBIO() {
328 bio_meth =
329 BIO_meth_new(BIO_get_new_index() | BIO_TYPE_SOURCE_SINK, "SSLConnectionBackendOpenSSL");
330 if (!bio_meth ||
331 !BIO_meth_set_write_ex(bio_meth, &SSLConnectionBackendOpenSSL::WriteCallback) ||
332 !BIO_meth_set_read_ex(bio_meth, &SSLConnectionBackendOpenSSL::ReadCallback) ||
333 !BIO_meth_set_ctrl(bio_meth, &SSLConnectionBackendOpenSSL::CtrlCallback)) {
334 LOG_ERROR(Service_SSL, "Failed to create BIO_METHOD");
335 return false;
336 }
337 return true;
338}
339
340} // namespace
341
342} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_schannel.cpp b/src/core/hle/service/ssl/ssl_backend_schannel.cpp
new file mode 100644
index 000000000..0a326b536
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_schannel.cpp
@@ -0,0 +1,529 @@
1// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later
3
4#include "core/hle/service/ssl/ssl_backend.h"
5#include "core/internal_network/network.h"
6#include "core/internal_network/sockets.h"
7
8#include "common/error.h"
9#include "common/fs/file.h"
10#include "common/hex_util.h"
11#include "common/string_util.h"
12
13#include <mutex>
14
15#define SECURITY_WIN32
16#include <Security.h>
17#include <schnlsp.h>
18
19namespace {
20
21std::once_flag one_time_init_flag;
22bool one_time_init_success = false;
23
24SCHANNEL_CRED schannel_cred{
25 .dwVersion = SCHANNEL_CRED_VERSION,
26 .dwFlags = SCH_USE_STRONG_CRYPTO | // don't allow insecure protocols
27 SCH_CRED_AUTO_CRED_VALIDATION | // validate certs
28 SCH_CRED_NO_DEFAULT_CREDS, // don't automatically present a client certificate
29 // ^ I'm assuming that nobody would want to connect Yuzu to a
30 // service that requires some OS-provided corporate client
31 // certificate, and presenting one to some arbitrary server
32 // might be a privacy concern? Who knows, though.
33};
34
35CredHandle cred_handle;
36
37static void OneTimeInit() {
38 SECURITY_STATUS ret =
39 AcquireCredentialsHandle(nullptr, const_cast<LPTSTR>(UNISP_NAME), SECPKG_CRED_OUTBOUND,
40 nullptr, &schannel_cred, nullptr, nullptr, &cred_handle, nullptr);
41 if (ret != SEC_E_OK) {
42 // SECURITY_STATUS codes are a type of HRESULT and can be used with NativeErrorToString.
43 LOG_ERROR(Service_SSL, "AcquireCredentialsHandle failed: {}",
44 Common::NativeErrorToString(ret));
45 return;
46 }
47
48 one_time_init_success = true;
49}
50
51} // namespace
52
53namespace Service::SSL {
54
55class SSLConnectionBackendSchannel final : public SSLConnectionBackend {
56public:
57 Result Init() {
58 std::call_once(one_time_init_flag, OneTimeInit);
59
60 if (!one_time_init_success) {
61 LOG_ERROR(
62 Service_SSL,
63 "Can't create SSL connection because Schannel one-time initialization failed");
64 return ResultInternalError;
65 }
66
67 return ResultSuccess;
68 }
69
70 void SetSocket(std::shared_ptr<Network::SocketBase> socket) override {
71 socket_ = socket;
72 }
73
74 Result SetHostName(const std::string& hostname) override {
75 hostname_ = hostname;
76 return ResultSuccess;
77 }
78
79 Result DoHandshake() override {
80 while (1) {
81 Result r;
82 switch (handshake_state_) {
83 case HandshakeState::Initial:
84 if ((r = FlushCiphertextWriteBuf()) != ResultSuccess ||
85 (r = CallInitializeSecurityContext()) != ResultSuccess) {
86 return r;
87 }
88 // CallInitializeSecurityContext updated `handshake_state_`.
89 continue;
90 case HandshakeState::ContinueNeeded:
91 case HandshakeState::IncompleteMessage:
92 if ((r = FlushCiphertextWriteBuf()) != ResultSuccess ||
93 (r = FillCiphertextReadBuf()) != ResultSuccess) {
94 return r;
95 }
96 if (ciphertext_read_buf_.empty()) {
97 LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
98 return ResultInternalError;
99 }
100 if ((r = CallInitializeSecurityContext()) != ResultSuccess) {
101 return r;
102 }
103 // CallInitializeSecurityContext updated `handshake_state_`.
104 continue;
105 case HandshakeState::DoneAfterFlush:
106 if ((r = FlushCiphertextWriteBuf()) != ResultSuccess) {
107 return r;
108 }
109 handshake_state_ = HandshakeState::Connected;
110 return ResultSuccess;
111 case HandshakeState::Connected:
112 LOG_ERROR(Service_SSL, "Called DoHandshake but we already handshook");
113 return ResultInternalError;
114 case HandshakeState::Error:
115 return ResultInternalError;
116 }
117 }
118 }
119
120 Result FillCiphertextReadBuf() {
121 size_t fill_size = read_buf_fill_size_ ? read_buf_fill_size_ : 4096;
122 read_buf_fill_size_ = 0;
123 // This unnecessarily zeroes the buffer; oh well.
124 size_t offset = ciphertext_read_buf_.size();
125 ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; });
126 ciphertext_read_buf_.resize(offset + fill_size, 0);
127 auto read_span = std::span(ciphertext_read_buf_).subspan(offset, fill_size);
128 auto [actual, err] = socket_->Recv(0, read_span);
129 switch (err) {
130 case Network::Errno::SUCCESS:
131 ASSERT(static_cast<size_t>(actual) <= fill_size);
132 ciphertext_read_buf_.resize(offset + actual);
133 return ResultSuccess;
134 case Network::Errno::AGAIN:
135 ciphertext_read_buf_.resize(offset);
136 return ResultWouldBlock;
137 default:
138 ciphertext_read_buf_.resize(offset);
139 LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err);
140 return ResultInternalError;
141 }
142 }
143
144 // Returns success if the write buffer has been completely emptied.
145 Result FlushCiphertextWriteBuf() {
146 while (!ciphertext_write_buf_.empty()) {
147 auto [actual, err] = socket_->Send(ciphertext_write_buf_, 0);
148 switch (err) {
149 case Network::Errno::SUCCESS:
150 ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf_.size());
151 ciphertext_write_buf_.erase(ciphertext_write_buf_.begin(),
152 ciphertext_write_buf_.begin() + actual);
153 break;
154 case Network::Errno::AGAIN:
155 return ResultWouldBlock;
156 default:
157 LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err);
158 return ResultInternalError;
159 }
160 }
161 return ResultSuccess;
162 }
163
164 Result CallInitializeSecurityContext() {
165 unsigned long req = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | ISC_REQ_INTEGRITY |
166 ISC_REQ_REPLAY_DETECT | ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM |
167 ISC_REQ_USE_SUPPLIED_CREDS;
168 unsigned long attr;
169 // https://learn.microsoft.com/en-us/windows/win32/secauthn/initializesecuritycontext--schannel
170 std::array<SecBuffer, 2> input_buffers{{
171 // only used if `initial_call_done`
172 {
173 // [0]
174 .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf_.size()),
175 .BufferType = SECBUFFER_TOKEN,
176 .pvBuffer = ciphertext_read_buf_.data(),
177 },
178 {
179 // [1] (will be replaced by SECBUFFER_MISSING when SEC_E_INCOMPLETE_MESSAGE is
180 // returned, or SECBUFFER_EXTRA when SEC_E_CONTINUE_NEEDED is returned if the
181 // whole buffer wasn't used)
182 .BufferType = SECBUFFER_EMPTY,
183 },
184 }};
185 std::array<SecBuffer, 2> output_buffers{{
186 {
187 .BufferType = SECBUFFER_TOKEN,
188 }, // [0]
189 {
190 .BufferType = SECBUFFER_ALERT,
191 }, // [1]
192 }};
193 SecBufferDesc input_desc{
194 .ulVersion = SECBUFFER_VERSION,
195 .cBuffers = static_cast<unsigned long>(input_buffers.size()),
196 .pBuffers = input_buffers.data(),
197 };
198 SecBufferDesc output_desc{
199 .ulVersion = SECBUFFER_VERSION,
200 .cBuffers = static_cast<unsigned long>(output_buffers.size()),
201 .pBuffers = output_buffers.data(),
202 };
203 ASSERT_OR_EXECUTE_MSG(
204 input_buffers[0].cbBuffer == ciphertext_read_buf_.size(),
205 { return ResultInternalError; }, "read buffer too large");
206
207 bool initial_call_done = handshake_state_ != HandshakeState::Initial;
208 if (initial_call_done) {
209 LOG_DEBUG(Service_SSL, "Passing {} bytes into InitializeSecurityContext",
210 ciphertext_read_buf_.size());
211 }
212
213 SECURITY_STATUS ret =
214 InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt_ : nullptr,
215 // Caller ensured we have set a hostname:
216 const_cast<char*>(hostname_.value().c_str()), req,
217 0, // Reserved1
218 0, // TargetDataRep not used with Schannel
219 initial_call_done ? &input_desc : nullptr,
220 0, // Reserved2
221 initial_call_done ? nullptr : &ctxt_, &output_desc, &attr,
222 nullptr); // ptsExpiry
223
224 if (output_buffers[0].pvBuffer) {
225 std::span span(static_cast<u8*>(output_buffers[0].pvBuffer),
226 output_buffers[0].cbBuffer);
227 ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), span.begin(), span.end());
228 FreeContextBuffer(output_buffers[0].pvBuffer);
229 }
230
231 if (output_buffers[1].pvBuffer) {
232 std::span span(static_cast<u8*>(output_buffers[1].pvBuffer),
233 output_buffers[1].cbBuffer);
234 // The documentation doesn't explain what format this data is in.
235 LOG_DEBUG(Service_SSL, "Got a {}-byte alert buffer: {}", span.size(),
236 Common::HexToString(span));
237 }
238
239 switch (ret) {
240 case SEC_I_CONTINUE_NEEDED:
241 LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_I_CONTINUE_NEEDED");
242 if (input_buffers[1].BufferType == SECBUFFER_EXTRA) {
243 LOG_DEBUG(Service_SSL, "EXTRA of size {}", input_buffers[1].cbBuffer);
244 ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf_.size());
245 ciphertext_read_buf_.erase(ciphertext_read_buf_.begin(),
246 ciphertext_read_buf_.end() - input_buffers[1].cbBuffer);
247 } else {
248 ASSERT(input_buffers[1].BufferType == SECBUFFER_EMPTY);
249 ciphertext_read_buf_.clear();
250 }
251 handshake_state_ = HandshakeState::ContinueNeeded;
252 return ResultSuccess;
253 case SEC_E_INCOMPLETE_MESSAGE:
254 LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_INCOMPLETE_MESSAGE");
255 ASSERT(input_buffers[1].BufferType == SECBUFFER_MISSING);
256 read_buf_fill_size_ = input_buffers[1].cbBuffer;
257 handshake_state_ = HandshakeState::IncompleteMessage;
258 return ResultSuccess;
259 case SEC_E_OK:
260 LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_OK");
261 ciphertext_read_buf_.clear();
262 handshake_state_ = HandshakeState::DoneAfterFlush;
263 return GrabStreamSizes();
264 default:
265 LOG_ERROR(Service_SSL,
266 "InitializeSecurityContext failed (probably certificate/protocol issue): {}",
267 Common::NativeErrorToString(ret));
268 handshake_state_ = HandshakeState::Error;
269 return ResultInternalError;
270 }
271 }
272
273 Result GrabStreamSizes() {
274 SECURITY_STATUS ret =
275 QueryContextAttributes(&ctxt_, SECPKG_ATTR_STREAM_SIZES, &stream_sizes_);
276 if (ret != SEC_E_OK) {
277 LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}",
278 Common::NativeErrorToString(ret));
279 handshake_state_ = HandshakeState::Error;
280 return ResultInternalError;
281 }
282 return ResultSuccess;
283 }
284
285 ResultVal<size_t> Read(std::span<u8> data) override {
286 if (handshake_state_ != HandshakeState::Connected) {
287 LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake");
288 return ResultInternalError;
289 }
290 if (data.size() == 0 || got_read_eof_) {
291 return size_t(0);
292 }
293 while (1) {
294 if (!cleartext_read_buf_.empty()) {
295 size_t read_size = std::min(cleartext_read_buf_.size(), data.size());
296 std::memcpy(data.data(), cleartext_read_buf_.data(), read_size);
297 cleartext_read_buf_.erase(cleartext_read_buf_.begin(),
298 cleartext_read_buf_.begin() + read_size);
299 return read_size;
300 }
301 if (!ciphertext_read_buf_.empty()) {
302 std::array<SecBuffer, 5> buffers{{
303 {
304 .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf_.size()),
305 .BufferType = SECBUFFER_DATA,
306 .pvBuffer = ciphertext_read_buf_.data(),
307 },
308 {
309 .BufferType = SECBUFFER_EMPTY,
310 },
311 {
312 .BufferType = SECBUFFER_EMPTY,
313 },
314 {
315 .BufferType = SECBUFFER_EMPTY,
316 },
317 }};
318 ASSERT_OR_EXECUTE_MSG(
319 buffers[0].cbBuffer == ciphertext_read_buf_.size(),
320 { return ResultInternalError; }, "read buffer too large");
321 SecBufferDesc desc{
322 .ulVersion = SECBUFFER_VERSION,
323 .cBuffers = static_cast<unsigned long>(buffers.size()),
324 .pBuffers = buffers.data(),
325 };
326 SECURITY_STATUS ret =
327 DecryptMessage(&ctxt_, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr);
328 switch (ret) {
329 case SEC_E_OK:
330 ASSERT_OR_EXECUTE(buffers[0].BufferType == SECBUFFER_STREAM_HEADER,
331 { return ResultInternalError; });
332 ASSERT_OR_EXECUTE(buffers[1].BufferType == SECBUFFER_DATA,
333 { return ResultInternalError; });
334 ASSERT_OR_EXECUTE(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER,
335 { return ResultInternalError; });
336 cleartext_read_buf_.assign(static_cast<u8*>(buffers[1].pvBuffer),
337 static_cast<u8*>(buffers[1].pvBuffer) +
338 buffers[1].cbBuffer);
339 if (buffers[3].BufferType == SECBUFFER_EXTRA) {
340 ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf_.size());
341 ciphertext_read_buf_.erase(ciphertext_read_buf_.begin(),
342 ciphertext_read_buf_.end() -
343 buffers[3].cbBuffer);
344 } else {
345 ASSERT(buffers[3].BufferType == SECBUFFER_EMPTY);
346 ciphertext_read_buf_.clear();
347 }
348 continue;
349 case SEC_E_INCOMPLETE_MESSAGE:
350 break;
351 case SEC_I_CONTEXT_EXPIRED:
352 // Server hung up by sending close_notify.
353 got_read_eof_ = true;
354 return size_t(0);
355 default:
356 LOG_ERROR(Service_SSL, "DecryptMessage failed: {}",
357 Common::NativeErrorToString(ret));
358 return ResultInternalError;
359 }
360 }
361 Result r = FillCiphertextReadBuf();
362 if (r != ResultSuccess) {
363 return r;
364 }
365 if (ciphertext_read_buf_.empty()) {
366 got_read_eof_ = true;
367 return size_t(0);
368 }
369 }
370 }
371
372 ResultVal<size_t> Write(std::span<const u8> data) override {
373 if (handshake_state_ != HandshakeState::Connected) {
374 LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake");
375 return ResultInternalError;
376 }
377 if (data.size() == 0) {
378 return size_t(0);
379 }
380 data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes_.cbMaximumMessage));
381 if (!cleartext_write_buf_.empty()) {
382 // Already in the middle of a write. It wouldn't make sense to not
383 // finish sending the entire buffer since TLS has
384 // header/MAC/padding/etc.
385 if (data.size() != cleartext_write_buf_.size() ||
386 std::memcmp(data.data(), cleartext_write_buf_.data(), data.size())) {
387 LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer");
388 return ResultInternalError;
389 }
390 return WriteAlreadyEncryptedData();
391 } else {
392 cleartext_write_buf_.assign(data.begin(), data.end());
393 }
394
395 std::vector<u8> header_buf(stream_sizes_.cbHeader, 0);
396 std::vector<u8> tmp_data_buf = cleartext_write_buf_;
397 std::vector<u8> trailer_buf(stream_sizes_.cbTrailer, 0);
398
399 std::array<SecBuffer, 3> buffers{{
400 {
401 .cbBuffer = stream_sizes_.cbHeader,
402 .BufferType = SECBUFFER_STREAM_HEADER,
403 .pvBuffer = header_buf.data(),
404 },
405 {
406 .cbBuffer = static_cast<unsigned long>(tmp_data_buf.size()),
407 .BufferType = SECBUFFER_DATA,
408 .pvBuffer = tmp_data_buf.data(),
409 },
410 {
411 .cbBuffer = stream_sizes_.cbTrailer,
412 .BufferType = SECBUFFER_STREAM_TRAILER,
413 .pvBuffer = trailer_buf.data(),
414 },
415 }};
416 ASSERT_OR_EXECUTE_MSG(
417 buffers[1].cbBuffer == tmp_data_buf.size(), { return ResultInternalError; },
418 "temp buffer too large");
419 SecBufferDesc desc{
420 .ulVersion = SECBUFFER_VERSION,
421 .cBuffers = static_cast<unsigned long>(buffers.size()),
422 .pBuffers = buffers.data(),
423 };
424
425 SECURITY_STATUS ret = EncryptMessage(&ctxt_, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0);
426 if (ret != SEC_E_OK) {
427 LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret));
428 return ResultInternalError;
429 }
430 ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), header_buf.begin(),
431 header_buf.end());
432 ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), tmp_data_buf.begin(),
433 tmp_data_buf.end());
434 ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), trailer_buf.begin(),
435 trailer_buf.end());
436 return WriteAlreadyEncryptedData();
437 }
438
439 ResultVal<size_t> WriteAlreadyEncryptedData() {
440 Result r = FlushCiphertextWriteBuf();
441 if (r != ResultSuccess) {
442 return r;
443 }
444 // write buf is empty
445 size_t cleartext_bytes_written = cleartext_write_buf_.size();
446 cleartext_write_buf_.clear();
447 return cleartext_bytes_written;
448 }
449
450 ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
451 PCCERT_CONTEXT returned_cert = nullptr;
452 SECURITY_STATUS ret =
453 QueryContextAttributes(&ctxt_, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert);
454 if (ret != SEC_E_OK) {
455 LOG_ERROR(Service_SSL,
456 "QueryContextAttributes(SECPKG_ATTR_REMOTE_CERT_CONTEXT) failed: {}",
457 Common::NativeErrorToString(ret));
458 return ResultInternalError;
459 }
460 PCCERT_CONTEXT some_cert = nullptr;
461 std::vector<std::vector<u8>> certs;
462 while ((some_cert = CertEnumCertificatesInStore(returned_cert->hCertStore, some_cert))) {
463 certs.emplace_back(static_cast<u8*>(some_cert->pbCertEncoded),
464 static_cast<u8*>(some_cert->pbCertEncoded) +
465 some_cert->cbCertEncoded);
466 }
467 std::reverse(certs.begin(),
468 certs.end()); // Windows returns certs in reverse order from what we want
469 CertFreeCertificateContext(returned_cert);
470 return certs;
471 }
472
473 ~SSLConnectionBackendSchannel() {
474 if (handshake_state_ != HandshakeState::Initial) {
475 DeleteSecurityContext(&ctxt_);
476 }
477 }
478
479 enum class HandshakeState {
480 // Haven't called anything yet.
481 Initial,
482 // `SEC_I_CONTINUE_NEEDED` was returned by
483 // `InitializeSecurityContext`; must finish sending data (if any) in
484 // the write buffer, then read at least one byte before calling
485 // `InitializeSecurityContext` again.
486 ContinueNeeded,
487 // `SEC_E_INCOMPLETE_MESSAGE` was returned by
488 // `InitializeSecurityContext`; hopefully the write buffer is empty;
489 // must read at least one byte before calling
490 // `InitializeSecurityContext` again.
491 IncompleteMessage,
492 // `SEC_E_OK` was returned by `InitializeSecurityContext`; must
493 // finish sending data in the write buffer before having `DoHandshake`
494 // report success.
495 DoneAfterFlush,
496 // We finished the above and are now connected. At this point, writing
497 // and reading are separate 'state machines' represented by the
498 // nonemptiness of the ciphertext and cleartext read and write buffers.
499 Connected,
500 // Another error was returned and we shouldn't allow initialization
501 // to continue.
502 Error,
503 } handshake_state_ = HandshakeState::Initial;
504
505 CtxtHandle ctxt_;
506 SecPkgContext_StreamSizes stream_sizes_;
507
508 std::shared_ptr<Network::SocketBase> socket_;
509 std::optional<std::string> hostname_;
510
511 std::vector<u8> ciphertext_read_buf_;
512 std::vector<u8> ciphertext_write_buf_;
513 std::vector<u8> cleartext_read_buf_;
514 std::vector<u8> cleartext_write_buf_;
515
516 bool got_read_eof_ = false;
517 size_t read_buf_fill_size_ = 0;
518};
519
520ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
521 auto conn = std::make_unique<SSLConnectionBackendSchannel>();
522 Result res = conn->Init();
523 if (res.IsFailure()) {
524 return res;
525 }
526 return conn;
527}
528
529} // namespace Service::SSL
diff --git a/src/core/internal_network/network.cpp b/src/core/internal_network/network.cpp
index 75ac10a9c..39381e06e 100644
--- a/src/core/internal_network/network.cpp
+++ b/src/core/internal_network/network.cpp
@@ -121,6 +121,8 @@ Errno TranslateNativeError(int e) {
121 return Errno::MSGSIZE; 121 return Errno::MSGSIZE;
122 case WSAETIMEDOUT: 122 case WSAETIMEDOUT:
123 return Errno::TIMEDOUT; 123 return Errno::TIMEDOUT;
124 case WSAEINPROGRESS:
125 return Errno::INPROGRESS;
124 default: 126 default:
125 UNIMPLEMENTED_MSG("Unimplemented errno={}", e); 127 UNIMPLEMENTED_MSG("Unimplemented errno={}", e);
126 return Errno::OTHER; 128 return Errno::OTHER;
@@ -195,6 +197,8 @@ bool EnableNonBlock(int fd, bool enable) {
195 197
196Errno TranslateNativeError(int e) { 198Errno TranslateNativeError(int e) {
197 switch (e) { 199 switch (e) {
200 case 0:
201 return Errno::SUCCESS;
198 case EBADF: 202 case EBADF:
199 return Errno::BADF; 203 return Errno::BADF;
200 case EINVAL: 204 case EINVAL:
@@ -219,8 +223,10 @@ Errno TranslateNativeError(int e) {
219 return Errno::MSGSIZE; 223 return Errno::MSGSIZE;
220 case ETIMEDOUT: 224 case ETIMEDOUT:
221 return Errno::TIMEDOUT; 225 return Errno::TIMEDOUT;
226 case EINPROGRESS:
227 return Errno::INPROGRESS;
222 default: 228 default:
223 UNIMPLEMENTED_MSG("Unimplemented errno={}", e); 229 UNIMPLEMENTED_MSG("Unimplemented errno={} ({})", e, strerror(e));
224 return Errno::OTHER; 230 return Errno::OTHER;
225 } 231 }
226} 232}
@@ -234,15 +240,84 @@ Errno GetAndLogLastError() {
234 int e = errno; 240 int e = errno;
235#endif 241#endif
236 const Errno err = TranslateNativeError(e); 242 const Errno err = TranslateNativeError(e);
237 if (err == Errno::AGAIN || err == Errno::TIMEDOUT) { 243 if (err == Errno::AGAIN || err == Errno::TIMEDOUT || err == Errno::INPROGRESS) {
244 // These happen during normal operation, so only log them at debug level.
245 LOG_DEBUG(Network, "Socket operation error: {}", Common::NativeErrorToString(e));
238 return err; 246 return err;
239 } 247 }
240 LOG_ERROR(Network, "Socket operation error: {}", Common::NativeErrorToString(e)); 248 LOG_ERROR(Network, "Socket operation error: {}", Common::NativeErrorToString(e));
241 return err; 249 return err;
242} 250}
243 251
244int TranslateDomain(Domain domain) { 252GetAddrInfoError TranslateGetAddrInfoErrorFromNative(int gai_err) {
253 switch (gai_err) {
254 case 0:
255 return GetAddrInfoError::SUCCESS;
256#ifdef EAI_ADDRFAMILY
257 case EAI_ADDRFAMILY:
258 return GetAddrInfoError::ADDRFAMILY;
259#endif
260 case EAI_AGAIN:
261 return GetAddrInfoError::AGAIN;
262 case EAI_BADFLAGS:
263 return GetAddrInfoError::BADFLAGS;
264 case EAI_FAIL:
265 return GetAddrInfoError::FAIL;
266 case EAI_FAMILY:
267 return GetAddrInfoError::FAMILY;
268 case EAI_MEMORY:
269 return GetAddrInfoError::MEMORY;
270 case EAI_NONAME:
271 return GetAddrInfoError::NONAME;
272 case EAI_SERVICE:
273 return GetAddrInfoError::SERVICE;
274 case EAI_SOCKTYPE:
275 return GetAddrInfoError::SOCKTYPE;
276 // These codes may not be defined on all systems:
277#ifdef EAI_SYSTEM
278 case EAI_SYSTEM:
279 return GetAddrInfoError::SYSTEM;
280#endif
281#ifdef EAI_BADHINTS
282 case EAI_BADHINTS:
283 return GetAddrInfoError::BADHINTS;
284#endif
285#ifdef EAI_PROTOCOL
286 case EAI_PROTOCOL:
287 return GetAddrInfoError::PROTOCOL;
288#endif
289#ifdef EAI_OVERFLOW
290 case EAI_OVERFLOW:
291 return GetAddrInfoError::OVERFLOW_;
292#endif
293 default:
294#ifdef EAI_NODATA
295 // This can't be a case statement because it would create a duplicate
296 // case on Windows where EAI_NODATA is an alias for EAI_NONAME.
297 if (gai_err == EAI_NODATA) {
298 return GetAddrInfoError::NODATA;
299 }
300#endif
301 return GetAddrInfoError::OTHER;
302 }
303}
304
305Domain TranslateDomainFromNative(int domain) {
245 switch (domain) { 306 switch (domain) {
307 case 0:
308 return Domain::Unspecified;
309 case AF_INET:
310 return Domain::INET;
311 default:
312 UNIMPLEMENTED_MSG("Unhandled domain={}", domain);
313 return Domain::INET;
314 }
315}
316
317int TranslateDomainToNative(Domain domain) {
318 switch (domain) {
319 case Domain::Unspecified:
320 return 0;
246 case Domain::INET: 321 case Domain::INET:
247 return AF_INET; 322 return AF_INET;
248 default: 323 default:
@@ -251,20 +326,58 @@ int TranslateDomain(Domain domain) {
251 } 326 }
252} 327}
253 328
254int TranslateType(Type type) { 329Type TranslateTypeFromNative(int type) {
330 switch (type) {
331 case 0:
332 return Type::Unspecified;
333 case SOCK_STREAM:
334 return Type::STREAM;
335 case SOCK_DGRAM:
336 return Type::DGRAM;
337 case SOCK_RAW:
338 return Type::RAW;
339 case SOCK_SEQPACKET:
340 return Type::SEQPACKET;
341 default:
342 UNIMPLEMENTED_MSG("Unimplemented type={}", type);
343 return Type::STREAM;
344 }
345}
346
347int TranslateTypeToNative(Type type) {
255 switch (type) { 348 switch (type) {
349 case Type::Unspecified:
350 return 0;
256 case Type::STREAM: 351 case Type::STREAM:
257 return SOCK_STREAM; 352 return SOCK_STREAM;
258 case Type::DGRAM: 353 case Type::DGRAM:
259 return SOCK_DGRAM; 354 return SOCK_DGRAM;
355 case Type::RAW:
356 return SOCK_RAW;
260 default: 357 default:
261 UNIMPLEMENTED_MSG("Unimplemented type={}", type); 358 UNIMPLEMENTED_MSG("Unimplemented type={}", type);
262 return 0; 359 return 0;
263 } 360 }
264} 361}
265 362
266int TranslateProtocol(Protocol protocol) { 363Protocol TranslateProtocolFromNative(int protocol) {
364 switch (protocol) {
365 case 0:
366 return Protocol::Unspecified;
367 case IPPROTO_TCP:
368 return Protocol::TCP;
369 case IPPROTO_UDP:
370 return Protocol::UDP;
371 default:
372 UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol);
373 return Protocol::Unspecified;
374 }
375}
376
377int TranslateProtocolToNative(Protocol protocol) {
267 switch (protocol) { 378 switch (protocol) {
379 case Protocol::Unspecified:
380 return 0;
268 case Protocol::TCP: 381 case Protocol::TCP:
269 return IPPROTO_TCP; 382 return IPPROTO_TCP;
270 case Protocol::UDP: 383 case Protocol::UDP:
@@ -275,21 +388,10 @@ int TranslateProtocol(Protocol protocol) {
275 } 388 }
276} 389}
277 390
278SockAddrIn TranslateToSockAddrIn(sockaddr input_) { 391SockAddrIn TranslateToSockAddrIn(sockaddr_in input, size_t input_len) {
279 sockaddr_in input;
280 std::memcpy(&input, &input_, sizeof(input));
281
282 SockAddrIn result; 392 SockAddrIn result;
283 393
284 switch (input.sin_family) { 394 result.family = TranslateDomainFromNative(input.sin_family);
285 case AF_INET:
286 result.family = Domain::INET;
287 break;
288 default:
289 UNIMPLEMENTED_MSG("Unhandled sockaddr family={}", input.sin_family);
290 result.family = Domain::INET;
291 break;
292 }
293 395
294 result.portno = ntohs(input.sin_port); 396 result.portno = ntohs(input.sin_port);
295 397
@@ -301,22 +403,28 @@ SockAddrIn TranslateToSockAddrIn(sockaddr input_) {
301short TranslatePollEvents(PollEvents events) { 403short TranslatePollEvents(PollEvents events) {
302 short result = 0; 404 short result = 0;
303 405
304 if (True(events & PollEvents::In)) { 406 const auto translate = [&result, &events](PollEvents guest, short host) {
305 events &= ~PollEvents::In; 407 if (True(events & guest)) {
306 result |= POLLIN; 408 events &= ~guest;
307 } 409 result |= host;
308 if (True(events & PollEvents::Pri)) { 410 }
309 events &= ~PollEvents::Pri; 411 };
412
413 translate(PollEvents::In, POLLIN);
414 translate(PollEvents::Pri, POLLPRI);
415 translate(PollEvents::Out, POLLOUT);
416 translate(PollEvents::Err, POLLERR);
417 translate(PollEvents::Hup, POLLHUP);
418 translate(PollEvents::Nval, POLLNVAL);
419 translate(PollEvents::RdNorm, POLLRDNORM);
420 translate(PollEvents::RdBand, POLLRDBAND);
421 translate(PollEvents::WrBand, POLLWRBAND);
422
310#ifdef _WIN32 423#ifdef _WIN32
424 if (True(events & PollEvents::Pri)) {
311 LOG_WARNING(Service, "Winsock doesn't support POLLPRI"); 425 LOG_WARNING(Service, "Winsock doesn't support POLLPRI");
312#else
313 result |= POLLPRI;
314#endif
315 }
316 if (True(events & PollEvents::Out)) {
317 events &= ~PollEvents::Out;
318 result |= POLLOUT;
319 } 426 }
427#endif
320 428
321 UNIMPLEMENTED_IF_MSG((u16)events != 0, "Unhandled guest events=0x{:x}", (u16)events); 429 UNIMPLEMENTED_IF_MSG((u16)events != 0, "Unhandled guest events=0x{:x}", (u16)events);
322 430
@@ -337,6 +445,10 @@ PollEvents TranslatePollRevents(short revents) {
337 translate(POLLOUT, PollEvents::Out); 445 translate(POLLOUT, PollEvents::Out);
338 translate(POLLERR, PollEvents::Err); 446 translate(POLLERR, PollEvents::Err);
339 translate(POLLHUP, PollEvents::Hup); 447 translate(POLLHUP, PollEvents::Hup);
448 translate(POLLNVAL, PollEvents::Nval);
449 translate(POLLRDNORM, PollEvents::RdNorm);
450 translate(POLLRDBAND, PollEvents::RdBand);
451 translate(POLLWRBAND, PollEvents::WrBand);
340 452
341 UNIMPLEMENTED_IF_MSG(revents != 0, "Unhandled host revents=0x{:x}", revents); 453 UNIMPLEMENTED_IF_MSG(revents != 0, "Unhandled host revents=0x{:x}", revents);
342 454
@@ -360,12 +472,53 @@ std::optional<IPv4Address> GetHostIPv4Address() {
360 return {}; 472 return {};
361 } 473 }
362 474
363 std::array<char, 16> ip_addr = {};
364 ASSERT(inet_ntop(AF_INET, &network_interface->ip_address, ip_addr.data(), sizeof(ip_addr)) !=
365 nullptr);
366 return TranslateIPv4(network_interface->ip_address); 475 return TranslateIPv4(network_interface->ip_address);
367} 476}
368 477
478std::string IPv4AddressToString(IPv4Address ip_addr) {
479 std::array<char, INET_ADDRSTRLEN> buf = {};
480 ASSERT(inet_ntop(AF_INET, &ip_addr, buf.data(), sizeof(buf)) == buf.data());
481 return std::string(buf.data());
482}
483
484u32 IPv4AddressToInteger(IPv4Address ip_addr) {
485 return static_cast<u32>(ip_addr[0]) << 24 | static_cast<u32>(ip_addr[1]) << 16 |
486 static_cast<u32>(ip_addr[2]) << 8 | static_cast<u32>(ip_addr[3]);
487}
488
489#undef GetAddrInfo // Windows defines it as a macro
490
491Common::Expected<std::vector<AddrInfo>, GetAddrInfoError> GetAddrInfo(
492 const std::string& host, const std::optional<std::string>& service) {
493 addrinfo hints{};
494 hints.ai_family = AF_INET; // Switch only supports IPv4.
495 addrinfo* addrinfo;
496 s32 gai_err = getaddrinfo(host.c_str(), service.has_value() ? service->c_str() : nullptr,
497 &hints, &addrinfo);
498 if (gai_err != 0) {
499 return Common::Unexpected(TranslateGetAddrInfoErrorFromNative(gai_err));
500 }
501 std::vector<AddrInfo> ret;
502 for (auto* current = addrinfo; current; current = current->ai_next) {
503 // We should only get AF_INET results due to the hints value.
504 ASSERT_OR_EXECUTE(addrinfo->ai_family == AF_INET &&
505 addrinfo->ai_addrlen == sizeof(sockaddr_in),
506 continue;);
507
508 AddrInfo& out = ret.emplace_back();
509 out.family = TranslateDomainFromNative(current->ai_family);
510 out.socket_type = TranslateTypeFromNative(current->ai_socktype);
511 out.protocol = TranslateProtocolFromNative(current->ai_protocol);
512 out.addr = TranslateToSockAddrIn(*reinterpret_cast<sockaddr_in*>(current->ai_addr),
513 current->ai_addrlen);
514 if (current->ai_canonname != nullptr) {
515 out.canon_name = current->ai_canonname;
516 }
517 }
518 freeaddrinfo(addrinfo);
519 return ret;
520}
521
369std::pair<s32, Errno> Poll(std::vector<PollFD>& pollfds, s32 timeout) { 522std::pair<s32, Errno> Poll(std::vector<PollFD>& pollfds, s32 timeout) {
370 const size_t num = pollfds.size(); 523 const size_t num = pollfds.size();
371 524
@@ -411,6 +564,18 @@ Socket::Socket(Socket&& rhs) noexcept {
411} 564}
412 565
413template <typename T> 566template <typename T>
567std::pair<T, Errno> Socket::GetSockOpt(SOCKET fd_, int option) {
568 T value{};
569 socklen_t len = sizeof(value);
570 const int result = getsockopt(fd_, SOL_SOCKET, option, reinterpret_cast<char*>(&value), &len);
571 if (result != SOCKET_ERROR) {
572 ASSERT(len == sizeof(value));
573 return {value, Errno::SUCCESS};
574 }
575 return {value, GetAndLogLastError()};
576}
577
578template <typename T>
414Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) { 579Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) {
415 const int result = 580 const int result =
416 setsockopt(fd_, SOL_SOCKET, option, reinterpret_cast<const char*>(&value), sizeof(value)); 581 setsockopt(fd_, SOL_SOCKET, option, reinterpret_cast<const char*>(&value), sizeof(value));
@@ -421,7 +586,8 @@ Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) {
421} 586}
422 587
423Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) { 588Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) {
424 fd = socket(TranslateDomain(domain), TranslateType(type), TranslateProtocol(protocol)); 589 fd = socket(TranslateDomainToNative(domain), TranslateTypeToNative(type),
590 TranslateProtocolToNative(protocol));
425 if (fd != INVALID_SOCKET) { 591 if (fd != INVALID_SOCKET) {
426 return Errno::SUCCESS; 592 return Errno::SUCCESS;
427 } 593 }
@@ -430,19 +596,17 @@ Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) {
430} 596}
431 597
432std::pair<SocketBase::AcceptResult, Errno> Socket::Accept() { 598std::pair<SocketBase::AcceptResult, Errno> Socket::Accept() {
433 sockaddr addr; 599 sockaddr_in addr;
434 socklen_t addrlen = sizeof(addr); 600 socklen_t addrlen = sizeof(addr);
435 const SOCKET new_socket = accept(fd, &addr, &addrlen); 601 const SOCKET new_socket = accept(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen);
436 602
437 if (new_socket == INVALID_SOCKET) { 603 if (new_socket == INVALID_SOCKET) {
438 return {AcceptResult{}, GetAndLogLastError()}; 604 return {AcceptResult{}, GetAndLogLastError()};
439 } 605 }
440 606
441 ASSERT(addrlen == sizeof(sockaddr_in));
442
443 AcceptResult result{ 607 AcceptResult result{
444 .socket = std::make_unique<Socket>(new_socket), 608 .socket = std::make_unique<Socket>(new_socket),
445 .sockaddr_in = TranslateToSockAddrIn(addr), 609 .sockaddr_in = TranslateToSockAddrIn(addr, addrlen),
446 }; 610 };
447 611
448 return {std::move(result), Errno::SUCCESS}; 612 return {std::move(result), Errno::SUCCESS};
@@ -458,25 +622,23 @@ Errno Socket::Connect(SockAddrIn addr_in) {
458} 622}
459 623
460std::pair<SockAddrIn, Errno> Socket::GetPeerName() { 624std::pair<SockAddrIn, Errno> Socket::GetPeerName() {
461 sockaddr addr; 625 sockaddr_in addr;
462 socklen_t addrlen = sizeof(addr); 626 socklen_t addrlen = sizeof(addr);
463 if (getpeername(fd, &addr, &addrlen) == SOCKET_ERROR) { 627 if (getpeername(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen) == SOCKET_ERROR) {
464 return {SockAddrIn{}, GetAndLogLastError()}; 628 return {SockAddrIn{}, GetAndLogLastError()};
465 } 629 }
466 630
467 ASSERT(addrlen == sizeof(sockaddr_in)); 631 return {TranslateToSockAddrIn(addr, addrlen), Errno::SUCCESS};
468 return {TranslateToSockAddrIn(addr), Errno::SUCCESS};
469} 632}
470 633
471std::pair<SockAddrIn, Errno> Socket::GetSockName() { 634std::pair<SockAddrIn, Errno> Socket::GetSockName() {
472 sockaddr addr; 635 sockaddr_in addr;
473 socklen_t addrlen = sizeof(addr); 636 socklen_t addrlen = sizeof(addr);
474 if (getsockname(fd, &addr, &addrlen) == SOCKET_ERROR) { 637 if (getsockname(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen) == SOCKET_ERROR) {
475 return {SockAddrIn{}, GetAndLogLastError()}; 638 return {SockAddrIn{}, GetAndLogLastError()};
476 } 639 }
477 640
478 ASSERT(addrlen == sizeof(sockaddr_in)); 641 return {TranslateToSockAddrIn(addr, addrlen), Errno::SUCCESS};
479 return {TranslateToSockAddrIn(addr), Errno::SUCCESS};
480} 642}
481 643
482Errno Socket::Bind(SockAddrIn addr) { 644Errno Socket::Bind(SockAddrIn addr) {
@@ -519,7 +681,7 @@ Errno Socket::Shutdown(ShutdownHow how) {
519 return GetAndLogLastError(); 681 return GetAndLogLastError();
520} 682}
521 683
522std::pair<s32, Errno> Socket::Recv(int flags, std::vector<u8>& message) { 684std::pair<s32, Errno> Socket::Recv(int flags, std::span<u8> message) {
523 ASSERT(flags == 0); 685 ASSERT(flags == 0);
524 ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); 686 ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
525 687
@@ -532,21 +694,20 @@ std::pair<s32, Errno> Socket::Recv(int flags, std::vector<u8>& message) {
532 return {-1, GetAndLogLastError()}; 694 return {-1, GetAndLogLastError()};
533} 695}
534 696
535std::pair<s32, Errno> Socket::RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) { 697std::pair<s32, Errno> Socket::RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) {
536 ASSERT(flags == 0); 698 ASSERT(flags == 0);
537 ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); 699 ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
538 700
539 sockaddr addr_in{}; 701 sockaddr_in addr_in{};
540 socklen_t addrlen = sizeof(addr_in); 702 socklen_t addrlen = sizeof(addr_in);
541 socklen_t* const p_addrlen = addr ? &addrlen : nullptr; 703 socklen_t* const p_addrlen = addr ? &addrlen : nullptr;
542 sockaddr* const p_addr_in = addr ? &addr_in : nullptr; 704 sockaddr* const p_addr_in = addr ? reinterpret_cast<sockaddr*>(&addr_in) : nullptr;
543 705
544 const auto result = recvfrom(fd, reinterpret_cast<char*>(message.data()), 706 const auto result = recvfrom(fd, reinterpret_cast<char*>(message.data()),
545 static_cast<int>(message.size()), 0, p_addr_in, p_addrlen); 707 static_cast<int>(message.size()), 0, p_addr_in, p_addrlen);
546 if (result != SOCKET_ERROR) { 708 if (result != SOCKET_ERROR) {
547 if (addr) { 709 if (addr) {
548 ASSERT(addrlen == sizeof(addr_in)); 710 *addr = TranslateToSockAddrIn(addr_in, addrlen);
549 *addr = TranslateToSockAddrIn(addr_in);
550 } 711 }
551 return {static_cast<s32>(result), Errno::SUCCESS}; 712 return {static_cast<s32>(result), Errno::SUCCESS};
552 } 713 }
@@ -597,6 +758,11 @@ Errno Socket::Close() {
597 return Errno::SUCCESS; 758 return Errno::SUCCESS;
598} 759}
599 760
761std::pair<Errno, Errno> Socket::GetPendingError() {
762 auto [pending_err, getsockopt_err] = GetSockOpt<int>(fd, SO_ERROR);
763 return {TranslateNativeError(pending_err), getsockopt_err};
764}
765
600Errno Socket::SetLinger(bool enable, u32 linger) { 766Errno Socket::SetLinger(bool enable, u32 linger) {
601 return SetSockOpt(fd, SO_LINGER, MakeLinger(enable, linger)); 767 return SetSockOpt(fd, SO_LINGER, MakeLinger(enable, linger));
602} 768}
diff --git a/src/core/internal_network/network.h b/src/core/internal_network/network.h
index 1e09a007a..96319bfc8 100644
--- a/src/core/internal_network/network.h
+++ b/src/core/internal_network/network.h
@@ -16,6 +16,11 @@
16#include <netinet/in.h> 16#include <netinet/in.h>
17#endif 17#endif
18 18
19namespace Common {
20template <typename T, typename E>
21class Expected;
22}
23
19namespace Network { 24namespace Network {
20 25
21class SocketBase; 26class SocketBase;
@@ -36,6 +41,26 @@ enum class Errno {
36 NETUNREACH, 41 NETUNREACH,
37 TIMEDOUT, 42 TIMEDOUT,
38 MSGSIZE, 43 MSGSIZE,
44 INPROGRESS,
45 OTHER,
46};
47
48enum class GetAddrInfoError {
49 SUCCESS,
50 ADDRFAMILY,
51 AGAIN,
52 BADFLAGS,
53 FAIL,
54 FAMILY,
55 MEMORY,
56 NODATA,
57 NONAME,
58 SERVICE,
59 SOCKTYPE,
60 SYSTEM,
61 BADHINTS,
62 PROTOCOL,
63 OVERFLOW_,
39 OTHER, 64 OTHER,
40}; 65};
41 66
@@ -49,6 +74,9 @@ enum class PollEvents : u16 {
49 Err = 1 << 3, 74 Err = 1 << 3,
50 Hup = 1 << 4, 75 Hup = 1 << 4,
51 Nval = 1 << 5, 76 Nval = 1 << 5,
77 RdNorm = 1 << 6,
78 RdBand = 1 << 7,
79 WrBand = 1 << 8,
52}; 80};
53 81
54DECLARE_ENUM_FLAG_OPERATORS(PollEvents); 82DECLARE_ENUM_FLAG_OPERATORS(PollEvents);
@@ -82,4 +110,10 @@ constexpr IPv4Address TranslateIPv4(in_addr addr) {
82/// @return human ordered IPv4 address (e.g. 192.168.0.1) as an array 110/// @return human ordered IPv4 address (e.g. 192.168.0.1) as an array
83std::optional<IPv4Address> GetHostIPv4Address(); 111std::optional<IPv4Address> GetHostIPv4Address();
84 112
113std::string IPv4AddressToString(IPv4Address ip_addr);
114u32 IPv4AddressToInteger(IPv4Address ip_addr);
115
116Common::Expected<std::vector<AddrInfo>, GetAddrInfoError> GetAddrInfo(
117 const std::string& host, const std::optional<std::string>& service);
118
85} // namespace Network 119} // namespace Network
diff --git a/src/core/internal_network/socket_proxy.cpp b/src/core/internal_network/socket_proxy.cpp
index 7a77171c2..44e9e3093 100644
--- a/src/core/internal_network/socket_proxy.cpp
+++ b/src/core/internal_network/socket_proxy.cpp
@@ -98,7 +98,7 @@ Errno ProxySocket::Shutdown(ShutdownHow how) {
98 return Errno::SUCCESS; 98 return Errno::SUCCESS;
99} 99}
100 100
101std::pair<s32, Errno> ProxySocket::Recv(int flags, std::vector<u8>& message) { 101std::pair<s32, Errno> ProxySocket::Recv(int flags, std::span<u8> message) {
102 LOG_WARNING(Network, "(STUBBED) called"); 102 LOG_WARNING(Network, "(STUBBED) called");
103 ASSERT(flags == 0); 103 ASSERT(flags == 0);
104 ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); 104 ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
@@ -106,7 +106,7 @@ std::pair<s32, Errno> ProxySocket::Recv(int flags, std::vector<u8>& message) {
106 return {static_cast<s32>(0), Errno::SUCCESS}; 106 return {static_cast<s32>(0), Errno::SUCCESS};
107} 107}
108 108
109std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) { 109std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) {
110 ASSERT(flags == 0); 110 ASSERT(flags == 0);
111 ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); 111 ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
112 112
@@ -140,8 +140,8 @@ std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::vector<u8>& message,
140 } 140 }
141} 141}
142 142
143std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& message, 143std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::span<u8> message, SockAddrIn* addr,
144 SockAddrIn* addr, std::size_t max_length) { 144 std::size_t max_length) {
145 ProxyPacket& packet = received_packets.front(); 145 ProxyPacket& packet = received_packets.front();
146 if (addr) { 146 if (addr) {
147 addr->family = Domain::INET; 147 addr->family = Domain::INET;
@@ -153,10 +153,7 @@ std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& mes
153 std::size_t read_bytes; 153 std::size_t read_bytes;
154 if (packet.data.size() > max_length) { 154 if (packet.data.size() > max_length) {
155 read_bytes = max_length; 155 read_bytes = max_length;
156 message.clear(); 156 memcpy(message.data(), packet.data.data(), max_length);
157 std::copy(packet.data.begin(), packet.data.begin() + read_bytes,
158 std::back_inserter(message));
159 message.resize(max_length);
160 157
161 if (protocol == Protocol::UDP) { 158 if (protocol == Protocol::UDP) {
162 if (!peek) { 159 if (!peek) {
@@ -171,9 +168,7 @@ std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& mes
171 } 168 }
172 } else { 169 } else {
173 read_bytes = packet.data.size(); 170 read_bytes = packet.data.size();
174 message.clear(); 171 memcpy(message.data(), packet.data.data(), read_bytes);
175 std::copy(packet.data.begin(), packet.data.end(), std::back_inserter(message));
176 message.resize(max_length);
177 if (!peek) { 172 if (!peek) {
178 received_packets.pop(); 173 received_packets.pop();
179 } 174 }
@@ -293,6 +288,11 @@ Errno ProxySocket::SetNonBlock(bool enable) {
293 return Errno::SUCCESS; 288 return Errno::SUCCESS;
294} 289}
295 290
291std::pair<Errno, Errno> ProxySocket::GetPendingError() {
292 LOG_DEBUG(Network, "(STUBBED) called");
293 return {Errno::SUCCESS, Errno::SUCCESS};
294}
295
296bool ProxySocket::IsOpened() const { 296bool ProxySocket::IsOpened() const {
297 return fd != INVALID_SOCKET; 297 return fd != INVALID_SOCKET;
298} 298}
diff --git a/src/core/internal_network/socket_proxy.h b/src/core/internal_network/socket_proxy.h
index 6e991fa38..e12c413d1 100644
--- a/src/core/internal_network/socket_proxy.h
+++ b/src/core/internal_network/socket_proxy.h
@@ -39,11 +39,11 @@ public:
39 39
40 Errno Shutdown(ShutdownHow how) override; 40 Errno Shutdown(ShutdownHow how) override;
41 41
42 std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) override; 42 std::pair<s32, Errno> Recv(int flags, std::span<u8> message) override;
43 43
44 std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) override; 44 std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) override;
45 45
46 std::pair<s32, Errno> ReceivePacket(int flags, std::vector<u8>& message, SockAddrIn* addr, 46 std::pair<s32, Errno> ReceivePacket(int flags, std::span<u8> message, SockAddrIn* addr,
47 std::size_t max_length); 47 std::size_t max_length);
48 48
49 std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override; 49 std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override;
@@ -74,6 +74,8 @@ public:
74 template <typename T> 74 template <typename T>
75 Errno SetSockOpt(SOCKET fd, int option, T value); 75 Errno SetSockOpt(SOCKET fd, int option, T value);
76 76
77 std::pair<Errno, Errno> GetPendingError() override;
78
77 bool IsOpened() const override; 79 bool IsOpened() const override;
78 80
79private: 81private:
diff --git a/src/core/internal_network/sockets.h b/src/core/internal_network/sockets.h
index 11e479e50..46a53ef79 100644
--- a/src/core/internal_network/sockets.h
+++ b/src/core/internal_network/sockets.h
@@ -59,10 +59,9 @@ public:
59 59
60 virtual Errno Shutdown(ShutdownHow how) = 0; 60 virtual Errno Shutdown(ShutdownHow how) = 0;
61 61
62 virtual std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) = 0; 62 virtual std::pair<s32, Errno> Recv(int flags, std::span<u8> message) = 0;
63 63
64 virtual std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message, 64 virtual std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) = 0;
65 SockAddrIn* addr) = 0;
66 65
67 virtual std::pair<s32, Errno> Send(std::span<const u8> message, int flags) = 0; 66 virtual std::pair<s32, Errno> Send(std::span<const u8> message, int flags) = 0;
68 67
@@ -87,6 +86,8 @@ public:
87 86
88 virtual Errno SetNonBlock(bool enable) = 0; 87 virtual Errno SetNonBlock(bool enable) = 0;
89 88
89 virtual std::pair<Errno, Errno> GetPendingError() = 0;
90
90 virtual bool IsOpened() const = 0; 91 virtual bool IsOpened() const = 0;
91 92
92 virtual void HandleProxyPacket(const ProxyPacket& packet) = 0; 93 virtual void HandleProxyPacket(const ProxyPacket& packet) = 0;
@@ -126,9 +127,9 @@ public:
126 127
127 Errno Shutdown(ShutdownHow how) override; 128 Errno Shutdown(ShutdownHow how) override;
128 129
129 std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) override; 130 std::pair<s32, Errno> Recv(int flags, std::span<u8> message) override;
130 131
131 std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) override; 132 std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) override;
132 133
133 std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override; 134 std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override;
134 135
@@ -156,6 +157,11 @@ public:
156 template <typename T> 157 template <typename T>
157 Errno SetSockOpt(SOCKET fd, int option, T value); 158 Errno SetSockOpt(SOCKET fd, int option, T value);
158 159
160 std::pair<Errno, Errno> GetPendingError() override;
161
162 template <typename T>
163 std::pair<T, Errno> GetSockOpt(SOCKET fd, int option);
164
159 bool IsOpened() const override; 165 bool IsOpened() const override;
160 166
161 void HandleProxyPacket(const ProxyPacket& packet) override; 167 void HandleProxyPacket(const ProxyPacket& packet) override;