diff options
23 files changed, 2413 insertions, 282 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 7f8febb90..647219052 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt | |||
| @@ -63,6 +63,18 @@ option(YUZU_DOWNLOAD_TIME_ZONE_DATA "Always download time zone binaries" OFF) | |||
| 63 | 63 | ||
| 64 | CMAKE_DEPENDENT_OPTION(YUZU_USE_FASTER_LD "Check if a faster linker is available" ON "NOT WIN32" OFF) | 64 | CMAKE_DEPENDENT_OPTION(YUZU_USE_FASTER_LD "Check if a faster linker is available" ON "NOT WIN32" OFF) |
| 65 | 65 | ||
| 66 | set(DEFAULT_ENABLE_OPENSSL ON) | ||
| 67 | if (ANDROID OR WIN32 OR APPLE) | ||
| 68 | # - Windows defaults to the Schannel backend. | ||
| 69 | # - macOS defaults to the SecureTransport backend. | ||
| 70 | # - Android currently has no SSL backend as the NDK doesn't include any SSL | ||
| 71 | # library; a proper 'native' backend would have to go through Java. | ||
| 72 | # But you can force builds for those platforms to use OpenSSL if you have | ||
| 73 | # your own copy of it. | ||
| 74 | set(DEFAULT_ENABLE_OPENSSL OFF) | ||
| 75 | endif() | ||
| 76 | option(ENABLE_OPENSSL "Enable OpenSSL backend for ISslConnection" ${DEFAULT_ENABLE_OPENSSL}) | ||
| 77 | |||
| 66 | # On Android, fetch and compile libcxx before doing anything else | 78 | # On Android, fetch and compile libcxx before doing anything else |
| 67 | if (ANDROID) | 79 | if (ANDROID) |
| 68 | set(CMAKE_SKIP_INSTALL_RULES ON) | 80 | set(CMAKE_SKIP_INSTALL_RULES ON) |
| @@ -322,6 +334,10 @@ if (MINGW) | |||
| 322 | find_library(MSWSOCK_LIBRARY mswsock REQUIRED) | 334 | find_library(MSWSOCK_LIBRARY mswsock REQUIRED) |
| 323 | endif() | 335 | endif() |
| 324 | 336 | ||
| 337 | if(ENABLE_OPENSSL) | ||
| 338 | find_package(OpenSSL 1.1.1 REQUIRED) | ||
| 339 | endif() | ||
| 340 | |||
| 325 | # Please consider this as a stub | 341 | # Please consider this as a stub |
| 326 | if(ENABLE_QT6 AND Qt6_LOCATION) | 342 | if(ENABLE_QT6 AND Qt6_LOCATION) |
| 327 | list(APPEND CMAKE_PREFIX_PATH "${Qt6_LOCATION}") | 343 | list(APPEND CMAKE_PREFIX_PATH "${Qt6_LOCATION}") |
diff --git a/src/common/socket_types.h b/src/common/socket_types.h index 0a801a443..b2191c2e8 100644 --- a/src/common/socket_types.h +++ b/src/common/socket_types.h | |||
| @@ -5,15 +5,19 @@ | |||
| 5 | 5 | ||
| 6 | #include "common/common_types.h" | 6 | #include "common/common_types.h" |
| 7 | 7 | ||
| 8 | #include <optional> | ||
| 9 | |||
| 8 | namespace Network { | 10 | namespace Network { |
| 9 | 11 | ||
| 10 | /// Address families | 12 | /// Address families |
| 11 | enum class Domain : u8 { | 13 | enum class Domain : u8 { |
| 12 | INET, ///< Address family for IPv4 | 14 | Unspecified, ///< Represents 0, used in getaddrinfo hints |
| 15 | INET, ///< Address family for IPv4 | ||
| 13 | }; | 16 | }; |
| 14 | 17 | ||
| 15 | /// Socket types | 18 | /// Socket types |
| 16 | enum class Type { | 19 | enum class Type { |
| 20 | Unspecified, ///< Represents 0, used in getaddrinfo hints | ||
| 17 | STREAM, | 21 | STREAM, |
| 18 | DGRAM, | 22 | DGRAM, |
| 19 | RAW, | 23 | RAW, |
| @@ -22,6 +26,7 @@ enum class Type { | |||
| 22 | 26 | ||
| 23 | /// Protocol values for sockets | 27 | /// Protocol values for sockets |
| 24 | enum class Protocol : u8 { | 28 | enum class Protocol : u8 { |
| 29 | Unspecified, ///< Represents 0, usable in various places | ||
| 25 | ICMP, | 30 | ICMP, |
| 26 | TCP, | 31 | TCP, |
| 27 | UDP, | 32 | UDP, |
| @@ -48,4 +53,13 @@ constexpr u32 FLAG_MSG_PEEK = 0x2; | |||
| 48 | constexpr u32 FLAG_MSG_DONTWAIT = 0x80; | 53 | constexpr u32 FLAG_MSG_DONTWAIT = 0x80; |
| 49 | constexpr u32 FLAG_O_NONBLOCK = 0x800; | 54 | constexpr u32 FLAG_O_NONBLOCK = 0x800; |
| 50 | 55 | ||
| 56 | /// Cross-platform addrinfo structure | ||
| 57 | struct AddrInfo { | ||
| 58 | Domain family; | ||
| 59 | Type socket_type; | ||
| 60 | Protocol protocol; | ||
| 61 | SockAddrIn addr; | ||
| 62 | std::optional<std::string> canon_name; | ||
| 63 | }; | ||
| 64 | |||
| 51 | } // namespace Network | 65 | } // namespace Network |
diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 28cb6f86f..c3b688c5d 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt | |||
| @@ -723,6 +723,7 @@ add_library(core STATIC | |||
| 723 | hle/service/spl/spl_types.h | 723 | hle/service/spl/spl_types.h |
| 724 | hle/service/ssl/ssl.cpp | 724 | hle/service/ssl/ssl.cpp |
| 725 | hle/service/ssl/ssl.h | 725 | hle/service/ssl/ssl.h |
| 726 | hle/service/ssl/ssl_backend.h | ||
| 726 | hle/service/time/clock_types.h | 727 | hle/service/time/clock_types.h |
| 727 | hle/service/time/ephemeral_network_system_clock_context_writer.h | 728 | hle/service/time/ephemeral_network_system_clock_context_writer.h |
| 728 | hle/service/time/ephemeral_network_system_clock_core.h | 729 | hle/service/time/ephemeral_network_system_clock_core.h |
| @@ -864,6 +865,23 @@ if (ARCHITECTURE_x86_64 OR ARCHITECTURE_arm64) | |||
| 864 | target_link_libraries(core PRIVATE dynarmic::dynarmic) | 865 | target_link_libraries(core PRIVATE dynarmic::dynarmic) |
| 865 | endif() | 866 | endif() |
| 866 | 867 | ||
| 868 | if(ENABLE_OPENSSL) | ||
| 869 | target_sources(core PRIVATE | ||
| 870 | hle/service/ssl/ssl_backend_openssl.cpp) | ||
| 871 | target_link_libraries(core PRIVATE OpenSSL::SSL) | ||
| 872 | elseif (APPLE) | ||
| 873 | target_sources(core PRIVATE | ||
| 874 | hle/service/ssl/ssl_backend_securetransport.cpp) | ||
| 875 | target_link_libraries(core PRIVATE "-framework Security") | ||
| 876 | elseif (WIN32) | ||
| 877 | target_sources(core PRIVATE | ||
| 878 | hle/service/ssl/ssl_backend_schannel.cpp) | ||
| 879 | target_link_libraries(core PRIVATE secur32) | ||
| 880 | else() | ||
| 881 | target_sources(core PRIVATE | ||
| 882 | hle/service/ssl/ssl_backend_none.cpp) | ||
| 883 | endif() | ||
| 884 | |||
| 867 | if (YUZU_USE_PRECOMPILED_HEADERS) | 885 | if (YUZU_USE_PRECOMPILED_HEADERS) |
| 868 | target_precompile_headers(core PRIVATE precompiled_headers.h) | 886 | target_precompile_headers(core PRIVATE precompiled_headers.h) |
| 869 | endif() | 887 | endif() |
diff --git a/src/core/hle/service/sockets/bsd.cpp b/src/core/hle/service/sockets/bsd.cpp index bce45d321..e63b0a357 100644 --- a/src/core/hle/service/sockets/bsd.cpp +++ b/src/core/hle/service/sockets/bsd.cpp | |||
| @@ -20,6 +20,9 @@ | |||
| 20 | #include "core/internal_network/sockets.h" | 20 | #include "core/internal_network/sockets.h" |
| 21 | #include "network/network.h" | 21 | #include "network/network.h" |
| 22 | 22 | ||
| 23 | using Common::Expected; | ||
| 24 | using Common::Unexpected; | ||
| 25 | |||
| 23 | namespace Service::Sockets { | 26 | namespace Service::Sockets { |
| 24 | 27 | ||
| 25 | namespace { | 28 | namespace { |
| @@ -265,16 +268,19 @@ void BSD::GetSockOpt(HLERequestContext& ctx) { | |||
| 265 | const u32 level = rp.Pop<u32>(); | 268 | const u32 level = rp.Pop<u32>(); |
| 266 | const auto optname = static_cast<OptName>(rp.Pop<u32>()); | 269 | const auto optname = static_cast<OptName>(rp.Pop<u32>()); |
| 267 | 270 | ||
| 268 | LOG_WARNING(Service, "(STUBBED) called. fd={} level={} optname=0x{:x}", fd, level, optname); | ||
| 269 | |||
| 270 | std::vector<u8> optval(ctx.GetWriteBufferSize()); | 271 | std::vector<u8> optval(ctx.GetWriteBufferSize()); |
| 271 | 272 | ||
| 273 | LOG_DEBUG(Service, "called. fd={} level={} optname=0x{:x} len=0x{:x}", fd, level, optname, | ||
| 274 | optval.size()); | ||
| 275 | |||
| 276 | const Errno err = GetSockOptImpl(fd, level, optname, optval); | ||
| 277 | |||
| 272 | ctx.WriteBuffer(optval); | 278 | ctx.WriteBuffer(optval); |
| 273 | 279 | ||
| 274 | IPC::ResponseBuilder rb{ctx, 5}; | 280 | IPC::ResponseBuilder rb{ctx, 5}; |
| 275 | rb.Push(ResultSuccess); | 281 | rb.Push(ResultSuccess); |
| 276 | rb.Push<s32>(-1); | 282 | rb.Push<s32>(err == Errno::SUCCESS ? 0 : -1); |
| 277 | rb.PushEnum(Errno::NOTCONN); | 283 | rb.PushEnum(err); |
| 278 | rb.Push<u32>(static_cast<u32>(optval.size())); | 284 | rb.Push<u32>(static_cast<u32>(optval.size())); |
| 279 | } | 285 | } |
| 280 | 286 | ||
| @@ -436,6 +442,31 @@ void BSD::Close(HLERequestContext& ctx) { | |||
| 436 | BuildErrnoResponse(ctx, CloseImpl(fd)); | 442 | BuildErrnoResponse(ctx, CloseImpl(fd)); |
| 437 | } | 443 | } |
| 438 | 444 | ||
| 445 | void BSD::DuplicateSocket(HLERequestContext& ctx) { | ||
| 446 | struct InputParameters { | ||
| 447 | s32 fd; | ||
| 448 | u64 reserved; | ||
| 449 | }; | ||
| 450 | static_assert(sizeof(InputParameters) == 0x10); | ||
| 451 | |||
| 452 | struct OutputParameters { | ||
| 453 | s32 ret; | ||
| 454 | Errno bsd_errno; | ||
| 455 | }; | ||
| 456 | static_assert(sizeof(OutputParameters) == 0x8); | ||
| 457 | |||
| 458 | IPC::RequestParser rp{ctx}; | ||
| 459 | auto input = rp.PopRaw<InputParameters>(); | ||
| 460 | |||
| 461 | Expected<s32, Errno> res = DuplicateSocketImpl(input.fd); | ||
| 462 | IPC::ResponseBuilder rb{ctx, 4}; | ||
| 463 | rb.Push(ResultSuccess); | ||
| 464 | rb.PushRaw(OutputParameters{ | ||
| 465 | .ret = res.value_or(0), | ||
| 466 | .bsd_errno = res ? Errno::SUCCESS : res.error(), | ||
| 467 | }); | ||
| 468 | } | ||
| 469 | |||
| 439 | void BSD::EventFd(HLERequestContext& ctx) { | 470 | void BSD::EventFd(HLERequestContext& ctx) { |
| 440 | IPC::RequestParser rp{ctx}; | 471 | IPC::RequestParser rp{ctx}; |
| 441 | const u64 initval = rp.Pop<u64>(); | 472 | const u64 initval = rp.Pop<u64>(); |
| @@ -477,12 +508,12 @@ std::pair<s32, Errno> BSD::SocketImpl(Domain domain, Type type, Protocol protoco | |||
| 477 | 508 | ||
| 478 | auto room_member = room_network.GetRoomMember().lock(); | 509 | auto room_member = room_network.GetRoomMember().lock(); |
| 479 | if (room_member && room_member->IsConnected()) { | 510 | if (room_member && room_member->IsConnected()) { |
| 480 | descriptor.socket = std::make_unique<Network::ProxySocket>(room_network); | 511 | descriptor.socket = std::make_shared<Network::ProxySocket>(room_network); |
| 481 | } else { | 512 | } else { |
| 482 | descriptor.socket = std::make_unique<Network::Socket>(); | 513 | descriptor.socket = std::make_shared<Network::Socket>(); |
| 483 | } | 514 | } |
| 484 | 515 | ||
| 485 | descriptor.socket->Initialize(Translate(domain), Translate(type), Translate(type, protocol)); | 516 | descriptor.socket->Initialize(Translate(domain), Translate(type), Translate(protocol)); |
| 486 | descriptor.is_connection_based = IsConnectionBased(type); | 517 | descriptor.is_connection_based = IsConnectionBased(type); |
| 487 | 518 | ||
| 488 | return {fd, Errno::SUCCESS}; | 519 | return {fd, Errno::SUCCESS}; |
| @@ -538,7 +569,7 @@ std::pair<s32, Errno> BSD::PollImpl(std::vector<u8>& write_buffer, std::span<con | |||
| 538 | std::transform(fds.begin(), fds.end(), host_pollfds.begin(), [this](PollFD pollfd) { | 569 | std::transform(fds.begin(), fds.end(), host_pollfds.begin(), [this](PollFD pollfd) { |
| 539 | Network::PollFD result; | 570 | Network::PollFD result; |
| 540 | result.socket = file_descriptors[pollfd.fd]->socket.get(); | 571 | result.socket = file_descriptors[pollfd.fd]->socket.get(); |
| 541 | result.events = TranslatePollEventsToHost(pollfd.events); | 572 | result.events = Translate(pollfd.events); |
| 542 | result.revents = Network::PollEvents{}; | 573 | result.revents = Network::PollEvents{}; |
| 543 | return result; | 574 | return result; |
| 544 | }); | 575 | }); |
| @@ -547,7 +578,7 @@ std::pair<s32, Errno> BSD::PollImpl(std::vector<u8>& write_buffer, std::span<con | |||
| 547 | 578 | ||
| 548 | const size_t num = host_pollfds.size(); | 579 | const size_t num = host_pollfds.size(); |
| 549 | for (size_t i = 0; i < num; ++i) { | 580 | for (size_t i = 0; i < num; ++i) { |
| 550 | fds[i].revents = TranslatePollEventsToGuest(host_pollfds[i].revents); | 581 | fds[i].revents = Translate(host_pollfds[i].revents); |
| 551 | } | 582 | } |
| 552 | std::memcpy(write_buffer.data(), fds.data(), length); | 583 | std::memcpy(write_buffer.data(), fds.data(), length); |
| 553 | 584 | ||
| @@ -617,7 +648,8 @@ Errno BSD::GetPeerNameImpl(s32 fd, std::vector<u8>& write_buffer) { | |||
| 617 | } | 648 | } |
| 618 | const SockAddrIn guest_addrin = Translate(addr_in); | 649 | const SockAddrIn guest_addrin = Translate(addr_in); |
| 619 | 650 | ||
| 620 | ASSERT(write_buffer.size() == sizeof(guest_addrin)); | 651 | ASSERT(write_buffer.size() >= sizeof(guest_addrin)); |
| 652 | write_buffer.resize(sizeof(guest_addrin)); | ||
| 621 | std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin)); | 653 | std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin)); |
| 622 | return Translate(bsd_errno); | 654 | return Translate(bsd_errno); |
| 623 | } | 655 | } |
| @@ -633,7 +665,8 @@ Errno BSD::GetSockNameImpl(s32 fd, std::vector<u8>& write_buffer) { | |||
| 633 | } | 665 | } |
| 634 | const SockAddrIn guest_addrin = Translate(addr_in); | 666 | const SockAddrIn guest_addrin = Translate(addr_in); |
| 635 | 667 | ||
| 636 | ASSERT(write_buffer.size() == sizeof(guest_addrin)); | 668 | ASSERT(write_buffer.size() >= sizeof(guest_addrin)); |
| 669 | write_buffer.resize(sizeof(guest_addrin)); | ||
| 637 | std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin)); | 670 | std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin)); |
| 638 | return Translate(bsd_errno); | 671 | return Translate(bsd_errno); |
| 639 | } | 672 | } |
| @@ -671,13 +704,47 @@ std::pair<s32, Errno> BSD::FcntlImpl(s32 fd, FcntlCmd cmd, s32 arg) { | |||
| 671 | } | 704 | } |
| 672 | } | 705 | } |
| 673 | 706 | ||
| 674 | Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval) { | 707 | Errno BSD::GetSockOptImpl(s32 fd, u32 level, OptName optname, std::vector<u8>& optval) { |
| 675 | UNIMPLEMENTED_IF(level != 0xffff); // SOL_SOCKET | 708 | if (!IsFileDescriptorValid(fd)) { |
| 709 | return Errno::BADF; | ||
| 710 | } | ||
| 711 | |||
| 712 | if (level != static_cast<u32>(SocketLevel::SOCKET)) { | ||
| 713 | UNIMPLEMENTED_MSG("Unknown getsockopt level"); | ||
| 714 | return Errno::SUCCESS; | ||
| 715 | } | ||
| 716 | |||
| 717 | Network::SocketBase* const socket = file_descriptors[fd]->socket.get(); | ||
| 718 | |||
| 719 | switch (optname) { | ||
| 720 | case OptName::ERROR_: { | ||
| 721 | auto [pending_err, getsockopt_err] = socket->GetPendingError(); | ||
| 722 | if (getsockopt_err == Network::Errno::SUCCESS) { | ||
| 723 | Errno translated_pending_err = Translate(pending_err); | ||
| 724 | ASSERT_OR_EXECUTE_MSG( | ||
| 725 | optval.size() == sizeof(Errno), { return Errno::INVAL; }, | ||
| 726 | "Incorrect getsockopt option size"); | ||
| 727 | optval.resize(sizeof(Errno)); | ||
| 728 | memcpy(optval.data(), &translated_pending_err, sizeof(Errno)); | ||
| 729 | } | ||
| 730 | return Translate(getsockopt_err); | ||
| 731 | } | ||
| 732 | default: | ||
| 733 | UNIMPLEMENTED_MSG("Unimplemented optname={}", optname); | ||
| 734 | return Errno::SUCCESS; | ||
| 735 | } | ||
| 736 | } | ||
| 676 | 737 | ||
| 738 | Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval) { | ||
| 677 | if (!IsFileDescriptorValid(fd)) { | 739 | if (!IsFileDescriptorValid(fd)) { |
| 678 | return Errno::BADF; | 740 | return Errno::BADF; |
| 679 | } | 741 | } |
| 680 | 742 | ||
| 743 | if (level != static_cast<u32>(SocketLevel::SOCKET)) { | ||
| 744 | UNIMPLEMENTED_MSG("Unknown setsockopt level"); | ||
| 745 | return Errno::SUCCESS; | ||
| 746 | } | ||
| 747 | |||
| 681 | Network::SocketBase* const socket = file_descriptors[fd]->socket.get(); | 748 | Network::SocketBase* const socket = file_descriptors[fd]->socket.get(); |
| 682 | 749 | ||
| 683 | if (optname == OptName::LINGER) { | 750 | if (optname == OptName::LINGER) { |
| @@ -711,6 +778,9 @@ Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, con | |||
| 711 | return Translate(socket->SetSndTimeo(value)); | 778 | return Translate(socket->SetSndTimeo(value)); |
| 712 | case OptName::RCVTIMEO: | 779 | case OptName::RCVTIMEO: |
| 713 | return Translate(socket->SetRcvTimeo(value)); | 780 | return Translate(socket->SetRcvTimeo(value)); |
| 781 | case OptName::NOSIGPIPE: | ||
| 782 | LOG_WARNING(Service, "(STUBBED) setting NOSIGPIPE to {}", value); | ||
| 783 | return Errno::SUCCESS; | ||
| 714 | default: | 784 | default: |
| 715 | UNIMPLEMENTED_MSG("Unimplemented optname={}", optname); | 785 | UNIMPLEMENTED_MSG("Unimplemented optname={}", optname); |
| 716 | return Errno::SUCCESS; | 786 | return Errno::SUCCESS; |
| @@ -841,6 +911,28 @@ Errno BSD::CloseImpl(s32 fd) { | |||
| 841 | return bsd_errno; | 911 | return bsd_errno; |
| 842 | } | 912 | } |
| 843 | 913 | ||
| 914 | Expected<s32, Errno> BSD::DuplicateSocketImpl(s32 fd) { | ||
| 915 | if (!IsFileDescriptorValid(fd)) { | ||
| 916 | return Unexpected(Errno::BADF); | ||
| 917 | } | ||
| 918 | |||
| 919 | const s32 new_fd = FindFreeFileDescriptorHandle(); | ||
| 920 | if (new_fd < 0) { | ||
| 921 | LOG_ERROR(Service, "No more file descriptors available"); | ||
| 922 | return Unexpected(Errno::MFILE); | ||
| 923 | } | ||
| 924 | |||
| 925 | file_descriptors[new_fd] = file_descriptors[fd]; | ||
| 926 | return new_fd; | ||
| 927 | } | ||
| 928 | |||
| 929 | std::optional<std::shared_ptr<Network::SocketBase>> BSD::GetSocket(s32 fd) { | ||
| 930 | if (!IsFileDescriptorValid(fd)) { | ||
| 931 | return std::nullopt; | ||
| 932 | } | ||
| 933 | return file_descriptors[fd]->socket; | ||
| 934 | } | ||
| 935 | |||
| 844 | s32 BSD::FindFreeFileDescriptorHandle() noexcept { | 936 | s32 BSD::FindFreeFileDescriptorHandle() noexcept { |
| 845 | for (s32 fd = 0; fd < static_cast<s32>(file_descriptors.size()); ++fd) { | 937 | for (s32 fd = 0; fd < static_cast<s32>(file_descriptors.size()); ++fd) { |
| 846 | if (!file_descriptors[fd]) { | 938 | if (!file_descriptors[fd]) { |
| @@ -911,7 +1003,7 @@ BSD::BSD(Core::System& system_, const char* name) | |||
| 911 | {24, &BSD::Write, "Write"}, | 1003 | {24, &BSD::Write, "Write"}, |
| 912 | {25, &BSD::Read, "Read"}, | 1004 | {25, &BSD::Read, "Read"}, |
| 913 | {26, &BSD::Close, "Close"}, | 1005 | {26, &BSD::Close, "Close"}, |
| 914 | {27, nullptr, "DuplicateSocket"}, | 1006 | {27, &BSD::DuplicateSocket, "DuplicateSocket"}, |
| 915 | {28, nullptr, "GetResourceStatistics"}, | 1007 | {28, nullptr, "GetResourceStatistics"}, |
| 916 | {29, nullptr, "RecvMMsg"}, | 1008 | {29, nullptr, "RecvMMsg"}, |
| 917 | {30, nullptr, "SendMMsg"}, | 1009 | {30, nullptr, "SendMMsg"}, |
diff --git a/src/core/hle/service/sockets/bsd.h b/src/core/hle/service/sockets/bsd.h index 30ae9c140..430edb97c 100644 --- a/src/core/hle/service/sockets/bsd.h +++ b/src/core/hle/service/sockets/bsd.h | |||
| @@ -8,6 +8,7 @@ | |||
| 8 | #include <vector> | 8 | #include <vector> |
| 9 | 9 | ||
| 10 | #include "common/common_types.h" | 10 | #include "common/common_types.h" |
| 11 | #include "common/expected.h" | ||
| 11 | #include "common/socket_types.h" | 12 | #include "common/socket_types.h" |
| 12 | #include "core/hle/service/service.h" | 13 | #include "core/hle/service/service.h" |
| 13 | #include "core/hle/service/sockets/sockets.h" | 14 | #include "core/hle/service/sockets/sockets.h" |
| @@ -29,12 +30,19 @@ public: | |||
| 29 | explicit BSD(Core::System& system_, const char* name); | 30 | explicit BSD(Core::System& system_, const char* name); |
| 30 | ~BSD() override; | 31 | ~BSD() override; |
| 31 | 32 | ||
| 33 | // These methods are called from SSL; the first two are also called from | ||
| 34 | // this class for the corresponding IPC methods. | ||
| 35 | // On the real device, the SSL service makes IPC calls to this service. | ||
| 36 | Common::Expected<s32, Errno> DuplicateSocketImpl(s32 fd); | ||
| 37 | Errno CloseImpl(s32 fd); | ||
| 38 | std::optional<std::shared_ptr<Network::SocketBase>> GetSocket(s32 fd); | ||
| 39 | |||
| 32 | private: | 40 | private: |
| 33 | /// Maximum number of file descriptors | 41 | /// Maximum number of file descriptors |
| 34 | static constexpr size_t MAX_FD = 128; | 42 | static constexpr size_t MAX_FD = 128; |
| 35 | 43 | ||
| 36 | struct FileDescriptor { | 44 | struct FileDescriptor { |
| 37 | std::unique_ptr<Network::SocketBase> socket; | 45 | std::shared_ptr<Network::SocketBase> socket; |
| 38 | s32 flags = 0; | 46 | s32 flags = 0; |
| 39 | bool is_connection_based = false; | 47 | bool is_connection_based = false; |
| 40 | }; | 48 | }; |
| @@ -138,6 +146,7 @@ private: | |||
| 138 | void Write(HLERequestContext& ctx); | 146 | void Write(HLERequestContext& ctx); |
| 139 | void Read(HLERequestContext& ctx); | 147 | void Read(HLERequestContext& ctx); |
| 140 | void Close(HLERequestContext& ctx); | 148 | void Close(HLERequestContext& ctx); |
| 149 | void DuplicateSocket(HLERequestContext& ctx); | ||
| 141 | void EventFd(HLERequestContext& ctx); | 150 | void EventFd(HLERequestContext& ctx); |
| 142 | 151 | ||
| 143 | template <typename Work> | 152 | template <typename Work> |
| @@ -153,6 +162,7 @@ private: | |||
| 153 | Errno GetSockNameImpl(s32 fd, std::vector<u8>& write_buffer); | 162 | Errno GetSockNameImpl(s32 fd, std::vector<u8>& write_buffer); |
| 154 | Errno ListenImpl(s32 fd, s32 backlog); | 163 | Errno ListenImpl(s32 fd, s32 backlog); |
| 155 | std::pair<s32, Errno> FcntlImpl(s32 fd, FcntlCmd cmd, s32 arg); | 164 | std::pair<s32, Errno> FcntlImpl(s32 fd, FcntlCmd cmd, s32 arg); |
| 165 | Errno GetSockOptImpl(s32 fd, u32 level, OptName optname, std::vector<u8>& optval); | ||
| 156 | Errno SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval); | 166 | Errno SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval); |
| 157 | Errno ShutdownImpl(s32 fd, s32 how); | 167 | Errno ShutdownImpl(s32 fd, s32 how); |
| 158 | std::pair<s32, Errno> RecvImpl(s32 fd, u32 flags, std::vector<u8>& message); | 168 | std::pair<s32, Errno> RecvImpl(s32 fd, u32 flags, std::vector<u8>& message); |
| @@ -161,7 +171,6 @@ private: | |||
| 161 | std::pair<s32, Errno> SendImpl(s32 fd, u32 flags, std::span<const u8> message); | 171 | std::pair<s32, Errno> SendImpl(s32 fd, u32 flags, std::span<const u8> message); |
| 162 | std::pair<s32, Errno> SendToImpl(s32 fd, u32 flags, std::span<const u8> message, | 172 | std::pair<s32, Errno> SendToImpl(s32 fd, u32 flags, std::span<const u8> message, |
| 163 | std::span<const u8> addr); | 173 | std::span<const u8> addr); |
| 164 | Errno CloseImpl(s32 fd); | ||
| 165 | 174 | ||
| 166 | s32 FindFreeFileDescriptorHandle() noexcept; | 175 | s32 FindFreeFileDescriptorHandle() noexcept; |
| 167 | bool IsFileDescriptorValid(s32 fd) const noexcept; | 176 | bool IsFileDescriptorValid(s32 fd) const noexcept; |
diff --git a/src/core/hle/service/sockets/nsd.cpp b/src/core/hle/service/sockets/nsd.cpp index 6491a73be..0dfb0f166 100644 --- a/src/core/hle/service/sockets/nsd.cpp +++ b/src/core/hle/service/sockets/nsd.cpp | |||
| @@ -1,10 +1,15 @@ | |||
| 1 | // SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project | 1 | // SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project |
| 2 | // SPDX-License-Identifier: GPL-2.0-or-later | 2 | // SPDX-License-Identifier: GPL-2.0-or-later |
| 3 | 3 | ||
| 4 | #include "core/hle/service/ipc_helpers.h" | ||
| 4 | #include "core/hle/service/sockets/nsd.h" | 5 | #include "core/hle/service/sockets/nsd.h" |
| 5 | 6 | ||
| 7 | #include "common/string_util.h" | ||
| 8 | |||
| 6 | namespace Service::Sockets { | 9 | namespace Service::Sockets { |
| 7 | 10 | ||
| 11 | constexpr Result ResultOverflow{ErrorModule::NSD, 6}; | ||
| 12 | |||
| 8 | NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, name} { | 13 | NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, name} { |
| 9 | // clang-format off | 14 | // clang-format off |
| 10 | static const FunctionInfo functions[] = { | 15 | static const FunctionInfo functions[] = { |
| @@ -15,8 +20,8 @@ NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, na | |||
| 15 | {13, nullptr, "DeleteSettings"}, | 20 | {13, nullptr, "DeleteSettings"}, |
| 16 | {14, nullptr, "ImportSettings"}, | 21 | {14, nullptr, "ImportSettings"}, |
| 17 | {15, nullptr, "SetChangeEnvironmentIdentifierDisabled"}, | 22 | {15, nullptr, "SetChangeEnvironmentIdentifierDisabled"}, |
| 18 | {20, nullptr, "Resolve"}, | 23 | {20, &NSD::Resolve, "Resolve"}, |
| 19 | {21, nullptr, "ResolveEx"}, | 24 | {21, &NSD::ResolveEx, "ResolveEx"}, |
| 20 | {30, nullptr, "GetNasServiceSetting"}, | 25 | {30, nullptr, "GetNasServiceSetting"}, |
| 21 | {31, nullptr, "GetNasServiceSettingEx"}, | 26 | {31, nullptr, "GetNasServiceSettingEx"}, |
| 22 | {40, nullptr, "GetNasRequestFqdn"}, | 27 | {40, nullptr, "GetNasRequestFqdn"}, |
| @@ -40,6 +45,55 @@ NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, na | |||
| 40 | RegisterHandlers(functions); | 45 | RegisterHandlers(functions); |
| 41 | } | 46 | } |
| 42 | 47 | ||
| 48 | static ResultVal<std::string> ResolveImpl(const std::string& fqdn_in) { | ||
| 49 | // The real implementation makes various substitutions. | ||
| 50 | // For now we just return the string as-is, which is good enough when not | ||
| 51 | // connecting to real Nintendo servers. | ||
| 52 | LOG_WARNING(Service, "(STUBBED) called, fqdn_in={}", fqdn_in); | ||
| 53 | return fqdn_in; | ||
| 54 | } | ||
| 55 | |||
| 56 | static Result ResolveCommon(const std::string& fqdn_in, std::array<char, 0x100>& fqdn_out) { | ||
| 57 | const auto res = ResolveImpl(fqdn_in); | ||
| 58 | if (res.Failed()) { | ||
| 59 | return res.Code(); | ||
| 60 | } | ||
| 61 | if (res->size() >= fqdn_out.size()) { | ||
| 62 | return ResultOverflow; | ||
| 63 | } | ||
| 64 | std::memcpy(fqdn_out.data(), res->c_str(), res->size() + 1); | ||
| 65 | return ResultSuccess; | ||
| 66 | } | ||
| 67 | |||
| 68 | void NSD::Resolve(HLERequestContext& ctx) { | ||
| 69 | const std::string fqdn_in = Common::StringFromBuffer(ctx.ReadBuffer(0)); | ||
| 70 | |||
| 71 | std::array<char, 0x100> fqdn_out{}; | ||
| 72 | const Result res = ResolveCommon(fqdn_in, fqdn_out); | ||
| 73 | |||
| 74 | ctx.WriteBuffer(fqdn_out); | ||
| 75 | IPC::ResponseBuilder rb{ctx, 2}; | ||
| 76 | rb.Push(res); | ||
| 77 | } | ||
| 78 | |||
| 79 | void NSD::ResolveEx(HLERequestContext& ctx) { | ||
| 80 | const std::string fqdn_in = Common::StringFromBuffer(ctx.ReadBuffer(0)); | ||
| 81 | |||
| 82 | std::array<char, 0x100> fqdn_out; | ||
| 83 | const Result res = ResolveCommon(fqdn_in, fqdn_out); | ||
| 84 | |||
| 85 | if (res.IsError()) { | ||
| 86 | IPC::ResponseBuilder rb{ctx, 2}; | ||
| 87 | rb.Push(res); | ||
| 88 | return; | ||
| 89 | } | ||
| 90 | |||
| 91 | ctx.WriteBuffer(fqdn_out); | ||
| 92 | IPC::ResponseBuilder rb{ctx, 4}; | ||
| 93 | rb.Push(ResultSuccess); | ||
| 94 | rb.Push(ResultSuccess); | ||
| 95 | } | ||
| 96 | |||
| 43 | NSD::~NSD() = default; | 97 | NSD::~NSD() = default; |
| 44 | 98 | ||
| 45 | } // namespace Service::Sockets | 99 | } // namespace Service::Sockets |
diff --git a/src/core/hle/service/sockets/nsd.h b/src/core/hle/service/sockets/nsd.h index 5cc12b855..a7379a8a9 100644 --- a/src/core/hle/service/sockets/nsd.h +++ b/src/core/hle/service/sockets/nsd.h | |||
| @@ -15,6 +15,10 @@ class NSD final : public ServiceFramework<NSD> { | |||
| 15 | public: | 15 | public: |
| 16 | explicit NSD(Core::System& system_, const char* name); | 16 | explicit NSD(Core::System& system_, const char* name); |
| 17 | ~NSD() override; | 17 | ~NSD() override; |
| 18 | |||
| 19 | private: | ||
| 20 | void Resolve(HLERequestContext& ctx); | ||
| 21 | void ResolveEx(HLERequestContext& ctx); | ||
| 18 | }; | 22 | }; |
| 19 | 23 | ||
| 20 | } // namespace Service::Sockets | 24 | } // namespace Service::Sockets |
diff --git a/src/core/hle/service/sockets/sfdnsres.cpp b/src/core/hle/service/sockets/sfdnsres.cpp index 132dd5797..84cc79de8 100644 --- a/src/core/hle/service/sockets/sfdnsres.cpp +++ b/src/core/hle/service/sockets/sfdnsres.cpp | |||
| @@ -10,27 +10,18 @@ | |||
| 10 | #include "core/core.h" | 10 | #include "core/core.h" |
| 11 | #include "core/hle/service/ipc_helpers.h" | 11 | #include "core/hle/service/ipc_helpers.h" |
| 12 | #include "core/hle/service/sockets/sfdnsres.h" | 12 | #include "core/hle/service/sockets/sfdnsres.h" |
| 13 | #include "core/hle/service/sockets/sockets.h" | ||
| 14 | #include "core/hle/service/sockets/sockets_translate.h" | ||
| 15 | #include "core/internal_network/network.h" | ||
| 13 | #include "core/memory.h" | 16 | #include "core/memory.h" |
| 14 | 17 | ||
| 15 | #ifdef _WIN32 | ||
| 16 | #include <ws2tcpip.h> | ||
| 17 | #elif YUZU_UNIX | ||
| 18 | #include <arpa/inet.h> | ||
| 19 | #include <netdb.h> | ||
| 20 | #include <netinet/in.h> | ||
| 21 | #include <sys/socket.h> | ||
| 22 | #ifndef EAI_NODATA | ||
| 23 | #define EAI_NODATA EAI_NONAME | ||
| 24 | #endif | ||
| 25 | #endif | ||
| 26 | |||
| 27 | namespace Service::Sockets { | 18 | namespace Service::Sockets { |
| 28 | 19 | ||
| 29 | SFDNSRES::SFDNSRES(Core::System& system_) : ServiceFramework{system_, "sfdnsres"} { | 20 | SFDNSRES::SFDNSRES(Core::System& system_) : ServiceFramework{system_, "sfdnsres"} { |
| 30 | static const FunctionInfo functions[] = { | 21 | static const FunctionInfo functions[] = { |
| 31 | {0, nullptr, "SetDnsAddressesPrivateRequest"}, | 22 | {0, nullptr, "SetDnsAddressesPrivateRequest"}, |
| 32 | {1, nullptr, "GetDnsAddressPrivateRequest"}, | 23 | {1, nullptr, "GetDnsAddressPrivateRequest"}, |
| 33 | {2, nullptr, "GetHostByNameRequest"}, | 24 | {2, &SFDNSRES::GetHostByNameRequest, "GetHostByNameRequest"}, |
| 34 | {3, nullptr, "GetHostByAddrRequest"}, | 25 | {3, nullptr, "GetHostByAddrRequest"}, |
| 35 | {4, nullptr, "GetHostStringErrorRequest"}, | 26 | {4, nullptr, "GetHostStringErrorRequest"}, |
| 36 | {5, nullptr, "GetGaiStringErrorRequest"}, | 27 | {5, nullptr, "GetGaiStringErrorRequest"}, |
| @@ -38,11 +29,11 @@ SFDNSRES::SFDNSRES(Core::System& system_) : ServiceFramework{system_, "sfdnsres" | |||
| 38 | {7, nullptr, "GetNameInfoRequest"}, | 29 | {7, nullptr, "GetNameInfoRequest"}, |
| 39 | {8, nullptr, "RequestCancelHandleRequest"}, | 30 | {8, nullptr, "RequestCancelHandleRequest"}, |
| 40 | {9, nullptr, "CancelRequest"}, | 31 | {9, nullptr, "CancelRequest"}, |
| 41 | {10, nullptr, "GetHostByNameRequestWithOptions"}, | 32 | {10, &SFDNSRES::GetHostByNameRequestWithOptions, "GetHostByNameRequestWithOptions"}, |
| 42 | {11, nullptr, "GetHostByAddrRequestWithOptions"}, | 33 | {11, nullptr, "GetHostByAddrRequestWithOptions"}, |
| 43 | {12, &SFDNSRES::GetAddrInfoRequestWithOptions, "GetAddrInfoRequestWithOptions"}, | 34 | {12, &SFDNSRES::GetAddrInfoRequestWithOptions, "GetAddrInfoRequestWithOptions"}, |
| 44 | {13, nullptr, "GetNameInfoRequestWithOptions"}, | 35 | {13, nullptr, "GetNameInfoRequestWithOptions"}, |
| 45 | {14, nullptr, "ResolverSetOptionRequest"}, | 36 | {14, &SFDNSRES::ResolverSetOptionRequest, "ResolverSetOptionRequest"}, |
| 46 | {15, nullptr, "ResolverGetOptionRequest"}, | 37 | {15, nullptr, "ResolverGetOptionRequest"}, |
| 47 | }; | 38 | }; |
| 48 | RegisterHandlers(functions); | 39 | RegisterHandlers(functions); |
| @@ -59,188 +50,285 @@ enum class NetDbError : s32 { | |||
| 59 | NoData = 4, | 50 | NoData = 4, |
| 60 | }; | 51 | }; |
| 61 | 52 | ||
| 62 | static NetDbError AddrInfoErrorToNetDbError(s32 result) { | 53 | static NetDbError GetAddrInfoErrorToNetDbError(GetAddrInfoError result) { |
| 63 | // Best effort guess to map errors | 54 | // These combinations have been verified on console (but are not |
| 55 | // exhaustive). | ||
| 64 | switch (result) { | 56 | switch (result) { |
| 65 | case 0: | 57 | case GetAddrInfoError::SUCCESS: |
| 66 | return NetDbError::Success; | 58 | return NetDbError::Success; |
| 67 | case EAI_AGAIN: | 59 | case GetAddrInfoError::AGAIN: |
| 68 | return NetDbError::TryAgain; | 60 | return NetDbError::TryAgain; |
| 69 | case EAI_NODATA: | 61 | case GetAddrInfoError::NODATA: |
| 70 | return NetDbError::NoData; | 62 | return NetDbError::HostNotFound; |
| 63 | case GetAddrInfoError::SERVICE: | ||
| 64 | return NetDbError::Success; | ||
| 71 | default: | 65 | default: |
| 72 | return NetDbError::HostNotFound; | 66 | return NetDbError::HostNotFound; |
| 73 | } | 67 | } |
| 74 | } | 68 | } |
| 75 | 69 | ||
| 76 | static std::vector<u8> SerializeAddrInfo(const addrinfo* addrinfo, s32 result_code, | 70 | static Errno GetAddrInfoErrorToErrno(GetAddrInfoError result) { |
| 71 | // These combinations have been verified on console (but are not | ||
| 72 | // exhaustive). | ||
| 73 | switch (result) { | ||
| 74 | case GetAddrInfoError::SUCCESS: | ||
| 75 | // Note: Sometimes a successful lookup sets errno to EADDRNOTAVAIL for | ||
| 76 | // some reason, but that doesn't seem useful to implement. | ||
| 77 | return Errno::SUCCESS; | ||
| 78 | case GetAddrInfoError::AGAIN: | ||
| 79 | return Errno::SUCCESS; | ||
| 80 | case GetAddrInfoError::NODATA: | ||
| 81 | return Errno::SUCCESS; | ||
| 82 | case GetAddrInfoError::SERVICE: | ||
| 83 | return Errno::INVAL; | ||
| 84 | default: | ||
| 85 | return Errno::SUCCESS; | ||
| 86 | } | ||
| 87 | } | ||
| 88 | |||
| 89 | template <typename T> | ||
| 90 | static void Append(std::vector<u8>& vec, T t) { | ||
| 91 | const size_t offset = vec.size(); | ||
| 92 | vec.resize(offset + sizeof(T)); | ||
| 93 | std::memcpy(vec.data() + offset, &t, sizeof(T)); | ||
| 94 | } | ||
| 95 | |||
| 96 | static void AppendNulTerminated(std::vector<u8>& vec, std::string_view str) { | ||
| 97 | const size_t offset = vec.size(); | ||
| 98 | vec.resize(offset + str.size() + 1); | ||
| 99 | std::memmove(vec.data() + offset, str.data(), str.size()); | ||
| 100 | } | ||
| 101 | |||
| 102 | // We implement gethostbyname using the host's getaddrinfo rather than the | ||
| 103 | // host's gethostbyname, because it simplifies portability: e.g., getaddrinfo | ||
| 104 | // behaves the same on Unix and Windows, unlike gethostbyname where Windows | ||
| 105 | // doesn't implement h_errno. | ||
| 106 | static std::vector<u8> SerializeAddrInfoAsHostEnt(const std::vector<Network::AddrInfo>& vec, | ||
| 107 | std::string_view host) { | ||
| 108 | |||
| 109 | std::vector<u8> data; | ||
| 110 | // h_name: use the input hostname (append nul-terminated) | ||
| 111 | AppendNulTerminated(data, host); | ||
| 112 | // h_aliases: leave empty | ||
| 113 | |||
| 114 | Append<u32_be>(data, 0); // count of h_aliases | ||
| 115 | // (If the count were nonzero, the aliases would be appended as nul-terminated here.) | ||
| 116 | Append<u16_be>(data, static_cast<u16>(Domain::INET)); // h_addrtype | ||
| 117 | Append<u16_be>(data, sizeof(Network::IPv4Address)); // h_length | ||
| 118 | // h_addr_list: | ||
| 119 | size_t count = vec.size(); | ||
| 120 | ASSERT(count <= UINT32_MAX); | ||
| 121 | Append<u32_be>(data, static_cast<uint32_t>(count)); | ||
| 122 | for (const Network::AddrInfo& addrinfo : vec) { | ||
| 123 | // On the Switch, this is passed through htonl despite already being | ||
| 124 | // big-endian, so it ends up as little-endian. | ||
| 125 | Append<u32_le>(data, Network::IPv4AddressToInteger(addrinfo.addr.ip)); | ||
| 126 | |||
| 127 | LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host, | ||
| 128 | Network::IPv4AddressToString(addrinfo.addr.ip)); | ||
| 129 | } | ||
| 130 | return data; | ||
| 131 | } | ||
| 132 | |||
| 133 | static std::pair<u32, GetAddrInfoError> GetHostByNameRequestImpl(HLERequestContext& ctx) { | ||
| 134 | struct InputParameters { | ||
| 135 | u8 use_nsd_resolve; | ||
| 136 | u32 cancel_handle; | ||
| 137 | u64 process_id; | ||
| 138 | }; | ||
| 139 | static_assert(sizeof(InputParameters) == 0x10); | ||
| 140 | |||
| 141 | IPC::RequestParser rp{ctx}; | ||
| 142 | const auto parameters = rp.PopRaw<InputParameters>(); | ||
| 143 | |||
| 144 | LOG_WARNING( | ||
| 145 | Service, | ||
| 146 | "called with ignored parameters: use_nsd_resolve={}, cancel_handle={}, process_id={}", | ||
| 147 | parameters.use_nsd_resolve, parameters.cancel_handle, parameters.process_id); | ||
| 148 | |||
| 149 | const auto host_buffer = ctx.ReadBuffer(0); | ||
| 150 | const std::string host = Common::StringFromBuffer(host_buffer); | ||
| 151 | // For now, ignore options, which are in input buffer 1 for GetHostByNameRequestWithOptions. | ||
| 152 | |||
| 153 | auto res = Network::GetAddressInfo(host, /*service*/ std::nullopt); | ||
| 154 | if (!res.has_value()) { | ||
| 155 | return {0, Translate(res.error())}; | ||
| 156 | } | ||
| 157 | |||
| 158 | const std::vector<u8> data = SerializeAddrInfoAsHostEnt(res.value(), host); | ||
| 159 | const u32 data_size = static_cast<u32>(data.size()); | ||
| 160 | ctx.WriteBuffer(data, 0); | ||
| 161 | |||
| 162 | return {data_size, GetAddrInfoError::SUCCESS}; | ||
| 163 | } | ||
| 164 | |||
| 165 | void SFDNSRES::GetHostByNameRequest(HLERequestContext& ctx) { | ||
| 166 | auto [data_size, emu_gai_err] = GetHostByNameRequestImpl(ctx); | ||
| 167 | |||
| 168 | struct OutputParameters { | ||
| 169 | NetDbError netdb_error; | ||
| 170 | Errno bsd_errno; | ||
| 171 | u32 data_size; | ||
| 172 | }; | ||
| 173 | static_assert(sizeof(OutputParameters) == 0xc); | ||
| 174 | |||
| 175 | IPC::ResponseBuilder rb{ctx, 5}; | ||
| 176 | rb.Push(ResultSuccess); | ||
| 177 | rb.PushRaw(OutputParameters{ | ||
| 178 | .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err), | ||
| 179 | .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err), | ||
| 180 | .data_size = data_size, | ||
| 181 | }); | ||
| 182 | } | ||
| 183 | |||
| 184 | void SFDNSRES::GetHostByNameRequestWithOptions(HLERequestContext& ctx) { | ||
| 185 | auto [data_size, emu_gai_err] = GetHostByNameRequestImpl(ctx); | ||
| 186 | |||
| 187 | struct OutputParameters { | ||
| 188 | u32 data_size; | ||
| 189 | NetDbError netdb_error; | ||
| 190 | Errno bsd_errno; | ||
| 191 | }; | ||
| 192 | static_assert(sizeof(OutputParameters) == 0xc); | ||
| 193 | |||
| 194 | IPC::ResponseBuilder rb{ctx, 5}; | ||
| 195 | rb.Push(ResultSuccess); | ||
| 196 | rb.PushRaw(OutputParameters{ | ||
| 197 | .data_size = data_size, | ||
| 198 | .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err), | ||
| 199 | .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err), | ||
| 200 | }); | ||
| 201 | } | ||
| 202 | |||
| 203 | static std::vector<u8> SerializeAddrInfo(const std::vector<Network::AddrInfo>& vec, | ||
| 77 | std::string_view host) { | 204 | std::string_view host) { |
| 78 | // Adapted from | 205 | // Adapted from |
| 79 | // https://github.com/switchbrew/libnx/blob/c5a9a909a91657a9818a3b7e18c9b91ff0cbb6e3/nx/source/runtime/resolver.c#L190 | 206 | // https://github.com/switchbrew/libnx/blob/c5a9a909a91657a9818a3b7e18c9b91ff0cbb6e3/nx/source/runtime/resolver.c#L190 |
| 80 | std::vector<u8> data; | 207 | std::vector<u8> data; |
| 81 | 208 | ||
| 82 | auto* current = addrinfo; | 209 | for (const Network::AddrInfo& addrinfo : vec) { |
| 83 | while (current != nullptr) { | 210 | // serialized addrinfo: |
| 84 | struct SerializedResponseHeader { | 211 | Append<u32_be>(data, 0xBEEFCAFE); // magic |
| 85 | u32 magic; | 212 | Append<u32_be>(data, 0); // ai_flags |
| 86 | s32 flags; | 213 | Append<u32_be>(data, static_cast<u32>(Translate(addrinfo.family))); // ai_family |
| 87 | s32 family; | 214 | Append<u32_be>(data, static_cast<u32>(Translate(addrinfo.socket_type))); // ai_socktype |
| 88 | s32 socket_type; | 215 | Append<u32_be>(data, static_cast<u32>(Translate(addrinfo.protocol))); // ai_protocol |
| 89 | s32 protocol; | 216 | Append<u32_be>(data, sizeof(SockAddrIn)); // ai_addrlen |
| 90 | u32 address_length; | 217 | // ^ *not* sizeof(SerializedSockAddrIn), not that it matters since they're the same size |
| 91 | }; | 218 | |
| 92 | static_assert(sizeof(SerializedResponseHeader) == 0x18, | 219 | // ai_addr: |
| 93 | "Response header size must be 0x18 bytes"); | 220 | Append<u16_be>(data, static_cast<u16>(Translate(addrinfo.addr.family))); // sin_family |
| 94 | 221 | // On the Switch, the following fields are passed through htonl despite | |
| 95 | constexpr auto header_size = sizeof(SerializedResponseHeader); | 222 | // already being big-endian, so they end up as little-endian. |
| 96 | const auto addr_size = | 223 | Append<u16_le>(data, addrinfo.addr.portno); // sin_port |
| 97 | current->ai_addr && current->ai_addrlen > 0 ? current->ai_addrlen : 4; | 224 | Append<u32_le>(data, Network::IPv4AddressToInteger(addrinfo.addr.ip)); // sin_addr |
| 98 | const auto canonname_size = current->ai_canonname ? strlen(current->ai_canonname) + 1 : 1; | 225 | data.resize(data.size() + 8, 0); // sin_zero |
| 99 | 226 | ||
| 100 | const auto last_size = data.size(); | 227 | if (addrinfo.canon_name.has_value()) { |
| 101 | data.resize(last_size + header_size + addr_size + canonname_size); | 228 | AppendNulTerminated(data, *addrinfo.canon_name); |
| 102 | |||
| 103 | // Header in network byte order | ||
| 104 | SerializedResponseHeader header{}; | ||
| 105 | |||
| 106 | constexpr auto HEADER_MAGIC = 0xBEEFCAFE; | ||
| 107 | header.magic = htonl(HEADER_MAGIC); | ||
| 108 | header.family = htonl(current->ai_family); | ||
| 109 | header.flags = htonl(current->ai_flags); | ||
| 110 | header.socket_type = htonl(current->ai_socktype); | ||
| 111 | header.protocol = htonl(current->ai_protocol); | ||
| 112 | header.address_length = current->ai_addr ? htonl((u32)current->ai_addrlen) : 0; | ||
| 113 | |||
| 114 | auto* header_ptr = data.data() + last_size; | ||
| 115 | std::memcpy(header_ptr, &header, header_size); | ||
| 116 | |||
| 117 | if (header.address_length == 0) { | ||
| 118 | std::memset(header_ptr + header_size, 0, 4); | ||
| 119 | } else { | ||
| 120 | switch (current->ai_family) { | ||
| 121 | case AF_INET: { | ||
| 122 | struct SockAddrIn { | ||
| 123 | s16 sin_family; | ||
| 124 | u16 sin_port; | ||
| 125 | u32 sin_addr; | ||
| 126 | u8 sin_zero[8]; | ||
| 127 | }; | ||
| 128 | |||
| 129 | SockAddrIn serialized_addr{}; | ||
| 130 | const auto addr = *reinterpret_cast<sockaddr_in*>(current->ai_addr); | ||
| 131 | serialized_addr.sin_port = htons(addr.sin_port); | ||
| 132 | serialized_addr.sin_family = htons(addr.sin_family); | ||
| 133 | serialized_addr.sin_addr = htonl(addr.sin_addr.s_addr); | ||
| 134 | std::memcpy(header_ptr + header_size, &serialized_addr, sizeof(SockAddrIn)); | ||
| 135 | |||
| 136 | char addr_string_buf[64]{}; | ||
| 137 | inet_ntop(AF_INET, &addr.sin_addr, addr_string_buf, std::size(addr_string_buf)); | ||
| 138 | LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host, addr_string_buf); | ||
| 139 | break; | ||
| 140 | } | ||
| 141 | case AF_INET6: { | ||
| 142 | struct SockAddrIn6 { | ||
| 143 | s16 sin6_family; | ||
| 144 | u16 sin6_port; | ||
| 145 | u32 sin6_flowinfo; | ||
| 146 | u8 sin6_addr[16]; | ||
| 147 | u32 sin6_scope_id; | ||
| 148 | }; | ||
| 149 | |||
| 150 | SockAddrIn6 serialized_addr{}; | ||
| 151 | const auto addr = *reinterpret_cast<sockaddr_in6*>(current->ai_addr); | ||
| 152 | serialized_addr.sin6_family = htons(addr.sin6_family); | ||
| 153 | serialized_addr.sin6_port = htons(addr.sin6_port); | ||
| 154 | serialized_addr.sin6_flowinfo = htonl(addr.sin6_flowinfo); | ||
| 155 | serialized_addr.sin6_scope_id = htonl(addr.sin6_scope_id); | ||
| 156 | std::memcpy(serialized_addr.sin6_addr, &addr.sin6_addr, | ||
| 157 | sizeof(SockAddrIn6::sin6_addr)); | ||
| 158 | std::memcpy(header_ptr + header_size, &serialized_addr, sizeof(SockAddrIn6)); | ||
| 159 | |||
| 160 | char addr_string_buf[64]{}; | ||
| 161 | inet_ntop(AF_INET6, &addr.sin6_addr, addr_string_buf, std::size(addr_string_buf)); | ||
| 162 | LOG_INFO(Service, "Resolved host '{}' to IPv6 address {}", host, addr_string_buf); | ||
| 163 | break; | ||
| 164 | } | ||
| 165 | default: | ||
| 166 | std::memcpy(header_ptr + header_size, current->ai_addr, addr_size); | ||
| 167 | break; | ||
| 168 | } | ||
| 169 | } | ||
| 170 | if (current->ai_canonname) { | ||
| 171 | std::memcpy(header_ptr + addr_size, current->ai_canonname, canonname_size); | ||
| 172 | } else { | 229 | } else { |
| 173 | *(header_ptr + header_size + addr_size) = 0; | 230 | data.push_back(0); |
| 174 | } | 231 | } |
| 175 | 232 | ||
| 176 | current = current->ai_next; | 233 | LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host, |
| 234 | Network::IPv4AddressToString(addrinfo.addr.ip)); | ||
| 177 | } | 235 | } |
| 178 | 236 | ||
| 179 | // 4-byte sentinel value | 237 | data.resize(data.size() + 4, 0); // 4-byte sentinel value |
| 180 | data.push_back(0); | ||
| 181 | data.push_back(0); | ||
| 182 | data.push_back(0); | ||
| 183 | data.push_back(0); | ||
| 184 | 238 | ||
| 185 | return data; | 239 | return data; |
| 186 | } | 240 | } |
| 187 | 241 | ||
| 188 | static std::pair<u32, s32> GetAddrInfoRequestImpl(HLERequestContext& ctx) { | 242 | static std::pair<u32, GetAddrInfoError> GetAddrInfoRequestImpl(HLERequestContext& ctx) { |
| 189 | struct Parameters { | 243 | struct InputParameters { |
| 190 | u8 use_nsd_resolve; | 244 | u8 use_nsd_resolve; |
| 191 | u32 unknown; | 245 | u32 cancel_handle; |
| 192 | u64 process_id; | 246 | u64 process_id; |
| 193 | }; | 247 | }; |
| 248 | static_assert(sizeof(InputParameters) == 0x10); | ||
| 194 | 249 | ||
| 195 | IPC::RequestParser rp{ctx}; | 250 | IPC::RequestParser rp{ctx}; |
| 196 | const auto parameters = rp.PopRaw<Parameters>(); | 251 | const auto parameters = rp.PopRaw<InputParameters>(); |
| 252 | |||
| 253 | LOG_WARNING( | ||
| 254 | Service, | ||
| 255 | "called with ignored parameters: use_nsd_resolve={}, cancel_handle={}, process_id={}", | ||
| 256 | parameters.use_nsd_resolve, parameters.cancel_handle, parameters.process_id); | ||
| 197 | 257 | ||
| 198 | LOG_WARNING(Service, | 258 | // TODO: If use_nsd_resolve is true, pass the name through NSD::Resolve |
| 199 | "called with ignored parameters: use_nsd_resolve={}, unknown={}, process_id={}", | 259 | // before looking up. |
| 200 | parameters.use_nsd_resolve, parameters.unknown, parameters.process_id); | ||
| 201 | 260 | ||
| 202 | const auto host_buffer = ctx.ReadBuffer(0); | 261 | const auto host_buffer = ctx.ReadBuffer(0); |
| 203 | const std::string host = Common::StringFromBuffer(host_buffer); | 262 | const std::string host = Common::StringFromBuffer(host_buffer); |
| 204 | 263 | ||
| 205 | const auto service_buffer = ctx.ReadBuffer(1); | 264 | std::optional<std::string> service = std::nullopt; |
| 206 | const std::string service = Common::StringFromBuffer(service_buffer); | 265 | if (ctx.CanReadBuffer(1)) { |
| 207 | 266 | const std::span<const u8> service_buffer = ctx.ReadBuffer(1); | |
| 208 | addrinfo* addrinfo; | 267 | service = Common::StringFromBuffer(service_buffer); |
| 209 | // Pass null for hints. Serialized hints are also passed in a buffer, but are ignored for now | 268 | } |
| 210 | s32 result_code = getaddrinfo(host.c_str(), service.c_str(), nullptr, &addrinfo); | ||
| 211 | 269 | ||
| 212 | u32 data_size = 0; | 270 | // Serialized hints are also passed in a buffer, but are ignored for now. |
| 213 | if (result_code == 0 && addrinfo != nullptr) { | ||
| 214 | const std::vector<u8>& data = SerializeAddrInfo(addrinfo, result_code, host); | ||
| 215 | data_size = static_cast<u32>(data.size()); | ||
| 216 | freeaddrinfo(addrinfo); | ||
| 217 | 271 | ||
| 218 | ctx.WriteBuffer(data, 0); | 272 | auto res = Network::GetAddressInfo(host, service); |
| 273 | if (!res.has_value()) { | ||
| 274 | return {0, Translate(res.error())}; | ||
| 219 | } | 275 | } |
| 220 | 276 | ||
| 221 | return std::make_pair(data_size, result_code); | 277 | const std::vector<u8> data = SerializeAddrInfo(res.value(), host); |
| 278 | const u32 data_size = static_cast<u32>(data.size()); | ||
| 279 | ctx.WriteBuffer(data, 0); | ||
| 280 | |||
| 281 | return {data_size, GetAddrInfoError::SUCCESS}; | ||
| 222 | } | 282 | } |
| 223 | 283 | ||
| 224 | void SFDNSRES::GetAddrInfoRequest(HLERequestContext& ctx) { | 284 | void SFDNSRES::GetAddrInfoRequest(HLERequestContext& ctx) { |
| 225 | auto [data_size, result_code] = GetAddrInfoRequestImpl(ctx); | 285 | auto [data_size, emu_gai_err] = GetAddrInfoRequestImpl(ctx); |
| 286 | |||
| 287 | struct OutputParameters { | ||
| 288 | Errno bsd_errno; | ||
| 289 | GetAddrInfoError gai_error; | ||
| 290 | u32 data_size; | ||
| 291 | }; | ||
| 292 | static_assert(sizeof(OutputParameters) == 0xc); | ||
| 226 | 293 | ||
| 227 | IPC::ResponseBuilder rb{ctx, 4}; | 294 | IPC::ResponseBuilder rb{ctx, 5}; |
| 228 | rb.Push(ResultSuccess); | 295 | rb.Push(ResultSuccess); |
| 229 | rb.Push(static_cast<s32>(AddrInfoErrorToNetDbError(result_code))); // NetDBErrorCode | 296 | rb.PushRaw(OutputParameters{ |
| 230 | rb.Push(result_code); // errno | 297 | .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err), |
| 231 | rb.Push(data_size); // serialized size | 298 | .gai_error = emu_gai_err, |
| 299 | .data_size = data_size, | ||
| 300 | }); | ||
| 232 | } | 301 | } |
| 233 | 302 | ||
| 234 | void SFDNSRES::GetAddrInfoRequestWithOptions(HLERequestContext& ctx) { | 303 | void SFDNSRES::GetAddrInfoRequestWithOptions(HLERequestContext& ctx) { |
| 235 | // Additional options are ignored | 304 | // Additional options are ignored |
| 236 | auto [data_size, result_code] = GetAddrInfoRequestImpl(ctx); | 305 | auto [data_size, emu_gai_err] = GetAddrInfoRequestImpl(ctx); |
| 306 | |||
| 307 | struct OutputParameters { | ||
| 308 | u32 data_size; | ||
| 309 | GetAddrInfoError gai_error; | ||
| 310 | NetDbError netdb_error; | ||
| 311 | Errno bsd_errno; | ||
| 312 | }; | ||
| 313 | static_assert(sizeof(OutputParameters) == 0x10); | ||
| 314 | |||
| 315 | IPC::ResponseBuilder rb{ctx, 6}; | ||
| 316 | rb.Push(ResultSuccess); | ||
| 317 | rb.PushRaw(OutputParameters{ | ||
| 318 | .data_size = data_size, | ||
| 319 | .gai_error = emu_gai_err, | ||
| 320 | .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err), | ||
| 321 | .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err), | ||
| 322 | }); | ||
| 323 | } | ||
| 324 | |||
| 325 | void SFDNSRES::ResolverSetOptionRequest(HLERequestContext& ctx) { | ||
| 326 | LOG_WARNING(Service, "(STUBBED) called"); | ||
| 327 | |||
| 328 | IPC::ResponseBuilder rb{ctx, 3}; | ||
| 237 | 329 | ||
| 238 | IPC::ResponseBuilder rb{ctx, 5}; | ||
| 239 | rb.Push(ResultSuccess); | 330 | rb.Push(ResultSuccess); |
| 240 | rb.Push(data_size); // serialized size | 331 | rb.Push<s32>(0); // bsd errno |
| 241 | rb.Push(result_code); // errno | ||
| 242 | rb.Push(static_cast<s32>(AddrInfoErrorToNetDbError(result_code))); // NetDBErrorCode | ||
| 243 | rb.Push(0); | ||
| 244 | } | 332 | } |
| 245 | 333 | ||
| 246 | } // namespace Service::Sockets | 334 | } // namespace Service::Sockets |
diff --git a/src/core/hle/service/sockets/sfdnsres.h b/src/core/hle/service/sockets/sfdnsres.h index 18e3cd60c..d99a9d560 100644 --- a/src/core/hle/service/sockets/sfdnsres.h +++ b/src/core/hle/service/sockets/sfdnsres.h | |||
| @@ -17,8 +17,11 @@ public: | |||
| 17 | ~SFDNSRES() override; | 17 | ~SFDNSRES() override; |
| 18 | 18 | ||
| 19 | private: | 19 | private: |
| 20 | void GetHostByNameRequest(HLERequestContext& ctx); | ||
| 21 | void GetHostByNameRequestWithOptions(HLERequestContext& ctx); | ||
| 20 | void GetAddrInfoRequest(HLERequestContext& ctx); | 22 | void GetAddrInfoRequest(HLERequestContext& ctx); |
| 21 | void GetAddrInfoRequestWithOptions(HLERequestContext& ctx); | 23 | void GetAddrInfoRequestWithOptions(HLERequestContext& ctx); |
| 24 | void ResolverSetOptionRequest(HLERequestContext& ctx); | ||
| 22 | }; | 25 | }; |
| 23 | 26 | ||
| 24 | } // namespace Service::Sockets | 27 | } // namespace Service::Sockets |
diff --git a/src/core/hle/service/sockets/sockets.h b/src/core/hle/service/sockets/sockets.h index acd2dae7b..77426c46e 100644 --- a/src/core/hle/service/sockets/sockets.h +++ b/src/core/hle/service/sockets/sockets.h | |||
| @@ -22,13 +22,35 @@ enum class Errno : u32 { | |||
| 22 | CONNRESET = 104, | 22 | CONNRESET = 104, |
| 23 | NOTCONN = 107, | 23 | NOTCONN = 107, |
| 24 | TIMEDOUT = 110, | 24 | TIMEDOUT = 110, |
| 25 | INPROGRESS = 115, | ||
| 26 | }; | ||
| 27 | |||
| 28 | enum class GetAddrInfoError : s32 { | ||
| 29 | SUCCESS = 0, | ||
| 30 | ADDRFAMILY = 1, | ||
| 31 | AGAIN = 2, | ||
| 32 | BADFLAGS = 3, | ||
| 33 | FAIL = 4, | ||
| 34 | FAMILY = 5, | ||
| 35 | MEMORY = 6, | ||
| 36 | NODATA = 7, | ||
| 37 | NONAME = 8, | ||
| 38 | SERVICE = 9, | ||
| 39 | SOCKTYPE = 10, | ||
| 40 | SYSTEM = 11, | ||
| 41 | BADHINTS = 12, | ||
| 42 | PROTOCOL = 13, | ||
| 43 | OVERFLOW_ = 14, // avoid name collision with Windows macro | ||
| 44 | OTHER = 15, | ||
| 25 | }; | 45 | }; |
| 26 | 46 | ||
| 27 | enum class Domain : u32 { | 47 | enum class Domain : u32 { |
| 48 | Unspecified = 0, | ||
| 28 | INET = 2, | 49 | INET = 2, |
| 29 | }; | 50 | }; |
| 30 | 51 | ||
| 31 | enum class Type : u32 { | 52 | enum class Type : u32 { |
| 53 | Unspecified = 0, | ||
| 32 | STREAM = 1, | 54 | STREAM = 1, |
| 33 | DGRAM = 2, | 55 | DGRAM = 2, |
| 34 | RAW = 3, | 56 | RAW = 3, |
| @@ -36,12 +58,16 @@ enum class Type : u32 { | |||
| 36 | }; | 58 | }; |
| 37 | 59 | ||
| 38 | enum class Protocol : u32 { | 60 | enum class Protocol : u32 { |
| 39 | UNSPECIFIED = 0, | 61 | Unspecified = 0, |
| 40 | ICMP = 1, | 62 | ICMP = 1, |
| 41 | TCP = 6, | 63 | TCP = 6, |
| 42 | UDP = 17, | 64 | UDP = 17, |
| 43 | }; | 65 | }; |
| 44 | 66 | ||
| 67 | enum class SocketLevel : u32 { | ||
| 68 | SOCKET = 0xffff, // i.e. SOL_SOCKET | ||
| 69 | }; | ||
| 70 | |||
| 45 | enum class OptName : u32 { | 71 | enum class OptName : u32 { |
| 46 | REUSEADDR = 0x4, | 72 | REUSEADDR = 0x4, |
| 47 | KEEPALIVE = 0x8, | 73 | KEEPALIVE = 0x8, |
| @@ -51,6 +77,8 @@ enum class OptName : u32 { | |||
| 51 | RCVBUF = 0x1002, | 77 | RCVBUF = 0x1002, |
| 52 | SNDTIMEO = 0x1005, | 78 | SNDTIMEO = 0x1005, |
| 53 | RCVTIMEO = 0x1006, | 79 | RCVTIMEO = 0x1006, |
| 80 | ERROR_ = 0x1007, // avoid name collision with Windows macro | ||
| 81 | NOSIGPIPE = 0x800, // at least according to libnx | ||
| 54 | }; | 82 | }; |
| 55 | 83 | ||
| 56 | enum class ShutdownHow : s32 { | 84 | enum class ShutdownHow : s32 { |
| @@ -80,6 +108,9 @@ enum class PollEvents : u16 { | |||
| 80 | Err = 1 << 3, | 108 | Err = 1 << 3, |
| 81 | Hup = 1 << 4, | 109 | Hup = 1 << 4, |
| 82 | Nval = 1 << 5, | 110 | Nval = 1 << 5, |
| 111 | RdNorm = 1 << 6, | ||
| 112 | RdBand = 1 << 7, | ||
| 113 | WrBand = 1 << 8, | ||
| 83 | }; | 114 | }; |
| 84 | 115 | ||
| 85 | DECLARE_ENUM_FLAG_OPERATORS(PollEvents); | 116 | DECLARE_ENUM_FLAG_OPERATORS(PollEvents); |
diff --git a/src/core/hle/service/sockets/sockets_translate.cpp b/src/core/hle/service/sockets/sockets_translate.cpp index 594e58f90..2f9a0e39c 100644 --- a/src/core/hle/service/sockets/sockets_translate.cpp +++ b/src/core/hle/service/sockets/sockets_translate.cpp | |||
| @@ -29,6 +29,8 @@ Errno Translate(Network::Errno value) { | |||
| 29 | return Errno::TIMEDOUT; | 29 | return Errno::TIMEDOUT; |
| 30 | case Network::Errno::CONNRESET: | 30 | case Network::Errno::CONNRESET: |
| 31 | return Errno::CONNRESET; | 31 | return Errno::CONNRESET; |
| 32 | case Network::Errno::INPROGRESS: | ||
| 33 | return Errno::INPROGRESS; | ||
| 32 | default: | 34 | default: |
| 33 | UNIMPLEMENTED_MSG("Unimplemented errno={}", value); | 35 | UNIMPLEMENTED_MSG("Unimplemented errno={}", value); |
| 34 | return Errno::SUCCESS; | 36 | return Errno::SUCCESS; |
| @@ -39,8 +41,50 @@ std::pair<s32, Errno> Translate(std::pair<s32, Network::Errno> value) { | |||
| 39 | return {value.first, Translate(value.second)}; | 41 | return {value.first, Translate(value.second)}; |
| 40 | } | 42 | } |
| 41 | 43 | ||
| 44 | GetAddrInfoError Translate(Network::GetAddrInfoError error) { | ||
| 45 | switch (error) { | ||
| 46 | case Network::GetAddrInfoError::SUCCESS: | ||
| 47 | return GetAddrInfoError::SUCCESS; | ||
| 48 | case Network::GetAddrInfoError::ADDRFAMILY: | ||
| 49 | return GetAddrInfoError::ADDRFAMILY; | ||
| 50 | case Network::GetAddrInfoError::AGAIN: | ||
| 51 | return GetAddrInfoError::AGAIN; | ||
| 52 | case Network::GetAddrInfoError::BADFLAGS: | ||
| 53 | return GetAddrInfoError::BADFLAGS; | ||
| 54 | case Network::GetAddrInfoError::FAIL: | ||
| 55 | return GetAddrInfoError::FAIL; | ||
| 56 | case Network::GetAddrInfoError::FAMILY: | ||
| 57 | return GetAddrInfoError::FAMILY; | ||
| 58 | case Network::GetAddrInfoError::MEMORY: | ||
| 59 | return GetAddrInfoError::MEMORY; | ||
| 60 | case Network::GetAddrInfoError::NODATA: | ||
| 61 | return GetAddrInfoError::NODATA; | ||
| 62 | case Network::GetAddrInfoError::NONAME: | ||
| 63 | return GetAddrInfoError::NONAME; | ||
| 64 | case Network::GetAddrInfoError::SERVICE: | ||
| 65 | return GetAddrInfoError::SERVICE; | ||
| 66 | case Network::GetAddrInfoError::SOCKTYPE: | ||
| 67 | return GetAddrInfoError::SOCKTYPE; | ||
| 68 | case Network::GetAddrInfoError::SYSTEM: | ||
| 69 | return GetAddrInfoError::SYSTEM; | ||
| 70 | case Network::GetAddrInfoError::BADHINTS: | ||
| 71 | return GetAddrInfoError::BADHINTS; | ||
| 72 | case Network::GetAddrInfoError::PROTOCOL: | ||
| 73 | return GetAddrInfoError::PROTOCOL; | ||
| 74 | case Network::GetAddrInfoError::OVERFLOW_: | ||
| 75 | return GetAddrInfoError::OVERFLOW_; | ||
| 76 | case Network::GetAddrInfoError::OTHER: | ||
| 77 | return GetAddrInfoError::OTHER; | ||
| 78 | default: | ||
| 79 | UNIMPLEMENTED_MSG("Unimplemented GetAddrInfoError={}", error); | ||
| 80 | return GetAddrInfoError::OTHER; | ||
| 81 | } | ||
| 82 | } | ||
| 83 | |||
| 42 | Network::Domain Translate(Domain domain) { | 84 | Network::Domain Translate(Domain domain) { |
| 43 | switch (domain) { | 85 | switch (domain) { |
| 86 | case Domain::Unspecified: | ||
| 87 | return Network::Domain::Unspecified; | ||
| 44 | case Domain::INET: | 88 | case Domain::INET: |
| 45 | return Network::Domain::INET; | 89 | return Network::Domain::INET; |
| 46 | default: | 90 | default: |
| @@ -51,6 +95,8 @@ Network::Domain Translate(Domain domain) { | |||
| 51 | 95 | ||
| 52 | Domain Translate(Network::Domain domain) { | 96 | Domain Translate(Network::Domain domain) { |
| 53 | switch (domain) { | 97 | switch (domain) { |
| 98 | case Network::Domain::Unspecified: | ||
| 99 | return Domain::Unspecified; | ||
| 54 | case Network::Domain::INET: | 100 | case Network::Domain::INET: |
| 55 | return Domain::INET; | 101 | return Domain::INET; |
| 56 | default: | 102 | default: |
| @@ -61,39 +107,69 @@ Domain Translate(Network::Domain domain) { | |||
| 61 | 107 | ||
| 62 | Network::Type Translate(Type type) { | 108 | Network::Type Translate(Type type) { |
| 63 | switch (type) { | 109 | switch (type) { |
| 110 | case Type::Unspecified: | ||
| 111 | return Network::Type::Unspecified; | ||
| 64 | case Type::STREAM: | 112 | case Type::STREAM: |
| 65 | return Network::Type::STREAM; | 113 | return Network::Type::STREAM; |
| 66 | case Type::DGRAM: | 114 | case Type::DGRAM: |
| 67 | return Network::Type::DGRAM; | 115 | return Network::Type::DGRAM; |
| 116 | case Type::RAW: | ||
| 117 | return Network::Type::RAW; | ||
| 118 | case Type::SEQPACKET: | ||
| 119 | return Network::Type::SEQPACKET; | ||
| 68 | default: | 120 | default: |
| 69 | UNIMPLEMENTED_MSG("Unimplemented type={}", type); | 121 | UNIMPLEMENTED_MSG("Unimplemented type={}", type); |
| 70 | return Network::Type{}; | 122 | return Network::Type{}; |
| 71 | } | 123 | } |
| 72 | } | 124 | } |
| 73 | 125 | ||
| 74 | Network::Protocol Translate(Type type, Protocol protocol) { | 126 | Type Translate(Network::Type type) { |
| 127 | switch (type) { | ||
| 128 | case Network::Type::Unspecified: | ||
| 129 | return Type::Unspecified; | ||
| 130 | case Network::Type::STREAM: | ||
| 131 | return Type::STREAM; | ||
| 132 | case Network::Type::DGRAM: | ||
| 133 | return Type::DGRAM; | ||
| 134 | case Network::Type::RAW: | ||
| 135 | return Type::RAW; | ||
| 136 | case Network::Type::SEQPACKET: | ||
| 137 | return Type::SEQPACKET; | ||
| 138 | default: | ||
| 139 | UNIMPLEMENTED_MSG("Unimplemented type={}", type); | ||
| 140 | return Type{}; | ||
| 141 | } | ||
| 142 | } | ||
| 143 | |||
| 144 | Network::Protocol Translate(Protocol protocol) { | ||
| 75 | switch (protocol) { | 145 | switch (protocol) { |
| 76 | case Protocol::UNSPECIFIED: | 146 | case Protocol::Unspecified: |
| 77 | LOG_WARNING(Service, "Unspecified protocol, assuming protocol from type"); | 147 | return Network::Protocol::Unspecified; |
| 78 | switch (type) { | ||
| 79 | case Type::DGRAM: | ||
| 80 | return Network::Protocol::UDP; | ||
| 81 | case Type::STREAM: | ||
| 82 | return Network::Protocol::TCP; | ||
| 83 | default: | ||
| 84 | return Network::Protocol::TCP; | ||
| 85 | } | ||
| 86 | case Protocol::TCP: | 148 | case Protocol::TCP: |
| 87 | return Network::Protocol::TCP; | 149 | return Network::Protocol::TCP; |
| 88 | case Protocol::UDP: | 150 | case Protocol::UDP: |
| 89 | return Network::Protocol::UDP; | 151 | return Network::Protocol::UDP; |
| 90 | default: | 152 | default: |
| 91 | UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol); | 153 | UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol); |
| 92 | return Network::Protocol::TCP; | 154 | return Network::Protocol::Unspecified; |
| 155 | } | ||
| 156 | } | ||
| 157 | |||
| 158 | Protocol Translate(Network::Protocol protocol) { | ||
| 159 | switch (protocol) { | ||
| 160 | case Network::Protocol::Unspecified: | ||
| 161 | return Protocol::Unspecified; | ||
| 162 | case Network::Protocol::TCP: | ||
| 163 | return Protocol::TCP; | ||
| 164 | case Network::Protocol::UDP: | ||
| 165 | return Protocol::UDP; | ||
| 166 | default: | ||
| 167 | UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol); | ||
| 168 | return Protocol::Unspecified; | ||
| 93 | } | 169 | } |
| 94 | } | 170 | } |
| 95 | 171 | ||
| 96 | Network::PollEvents TranslatePollEventsToHost(PollEvents flags) { | 172 | Network::PollEvents Translate(PollEvents flags) { |
| 97 | Network::PollEvents result{}; | 173 | Network::PollEvents result{}; |
| 98 | const auto translate = [&result, &flags](PollEvents from, Network::PollEvents to) { | 174 | const auto translate = [&result, &flags](PollEvents from, Network::PollEvents to) { |
| 99 | if (True(flags & from)) { | 175 | if (True(flags & from)) { |
| @@ -107,12 +183,15 @@ Network::PollEvents TranslatePollEventsToHost(PollEvents flags) { | |||
| 107 | translate(PollEvents::Err, Network::PollEvents::Err); | 183 | translate(PollEvents::Err, Network::PollEvents::Err); |
| 108 | translate(PollEvents::Hup, Network::PollEvents::Hup); | 184 | translate(PollEvents::Hup, Network::PollEvents::Hup); |
| 109 | translate(PollEvents::Nval, Network::PollEvents::Nval); | 185 | translate(PollEvents::Nval, Network::PollEvents::Nval); |
| 186 | translate(PollEvents::RdNorm, Network::PollEvents::RdNorm); | ||
| 187 | translate(PollEvents::RdBand, Network::PollEvents::RdBand); | ||
| 188 | translate(PollEvents::WrBand, Network::PollEvents::WrBand); | ||
| 110 | 189 | ||
| 111 | UNIMPLEMENTED_IF_MSG((u16)flags != 0, "Unimplemented flags={}", (u16)flags); | 190 | UNIMPLEMENTED_IF_MSG((u16)flags != 0, "Unimplemented flags={}", (u16)flags); |
| 112 | return result; | 191 | return result; |
| 113 | } | 192 | } |
| 114 | 193 | ||
| 115 | PollEvents TranslatePollEventsToGuest(Network::PollEvents flags) { | 194 | PollEvents Translate(Network::PollEvents flags) { |
| 116 | PollEvents result{}; | 195 | PollEvents result{}; |
| 117 | const auto translate = [&result, &flags](Network::PollEvents from, PollEvents to) { | 196 | const auto translate = [&result, &flags](Network::PollEvents from, PollEvents to) { |
| 118 | if (True(flags & from)) { | 197 | if (True(flags & from)) { |
| @@ -127,13 +206,18 @@ PollEvents TranslatePollEventsToGuest(Network::PollEvents flags) { | |||
| 127 | translate(Network::PollEvents::Err, PollEvents::Err); | 206 | translate(Network::PollEvents::Err, PollEvents::Err); |
| 128 | translate(Network::PollEvents::Hup, PollEvents::Hup); | 207 | translate(Network::PollEvents::Hup, PollEvents::Hup); |
| 129 | translate(Network::PollEvents::Nval, PollEvents::Nval); | 208 | translate(Network::PollEvents::Nval, PollEvents::Nval); |
| 209 | translate(Network::PollEvents::RdNorm, PollEvents::RdNorm); | ||
| 210 | translate(Network::PollEvents::RdBand, PollEvents::RdBand); | ||
| 211 | translate(Network::PollEvents::WrBand, PollEvents::WrBand); | ||
| 130 | 212 | ||
| 131 | UNIMPLEMENTED_IF_MSG((u16)flags != 0, "Unimplemented flags={}", (u16)flags); | 213 | UNIMPLEMENTED_IF_MSG((u16)flags != 0, "Unimplemented flags={}", (u16)flags); |
| 132 | return result; | 214 | return result; |
| 133 | } | 215 | } |
| 134 | 216 | ||
| 135 | Network::SockAddrIn Translate(SockAddrIn value) { | 217 | Network::SockAddrIn Translate(SockAddrIn value) { |
| 136 | ASSERT(value.len == 0 || value.len == sizeof(value)); | 218 | // Note: 6 is incorrect, but can be passed by homebrew (because libnx sets |
| 219 | // sin_len to 6 when deserializing getaddrinfo results). | ||
| 220 | ASSERT(value.len == 0 || value.len == sizeof(value) || value.len == 6); | ||
| 137 | 221 | ||
| 138 | return { | 222 | return { |
| 139 | .family = Translate(static_cast<Domain>(value.family)), | 223 | .family = Translate(static_cast<Domain>(value.family)), |
diff --git a/src/core/hle/service/sockets/sockets_translate.h b/src/core/hle/service/sockets/sockets_translate.h index c93291d3e..694868b37 100644 --- a/src/core/hle/service/sockets/sockets_translate.h +++ b/src/core/hle/service/sockets/sockets_translate.h | |||
| @@ -17,6 +17,9 @@ Errno Translate(Network::Errno value); | |||
| 17 | /// Translate abstract return value errno pair to guest return value errno pair | 17 | /// Translate abstract return value errno pair to guest return value errno pair |
| 18 | std::pair<s32, Errno> Translate(std::pair<s32, Network::Errno> value); | 18 | std::pair<s32, Errno> Translate(std::pair<s32, Network::Errno> value); |
| 19 | 19 | ||
| 20 | /// Translate abstract getaddrinfo error to guest getaddrinfo error | ||
| 21 | GetAddrInfoError Translate(Network::GetAddrInfoError value); | ||
| 22 | |||
| 20 | /// Translate guest domain to abstract domain | 23 | /// Translate guest domain to abstract domain |
| 21 | Network::Domain Translate(Domain domain); | 24 | Network::Domain Translate(Domain domain); |
| 22 | 25 | ||
| @@ -26,14 +29,20 @@ Domain Translate(Network::Domain domain); | |||
| 26 | /// Translate guest type to abstract type | 29 | /// Translate guest type to abstract type |
| 27 | Network::Type Translate(Type type); | 30 | Network::Type Translate(Type type); |
| 28 | 31 | ||
| 32 | /// Translate abstract type to guest type | ||
| 33 | Type Translate(Network::Type type); | ||
| 34 | |||
| 29 | /// Translate guest protocol to abstract protocol | 35 | /// Translate guest protocol to abstract protocol |
| 30 | Network::Protocol Translate(Type type, Protocol protocol); | 36 | Network::Protocol Translate(Protocol protocol); |
| 31 | 37 | ||
| 32 | /// Translate abstract poll event flags to guest poll event flags | 38 | /// Translate abstract protocol to guest protocol |
| 33 | Network::PollEvents TranslatePollEventsToHost(PollEvents flags); | 39 | Protocol Translate(Network::Protocol protocol); |
| 34 | 40 | ||
| 35 | /// Translate guest poll event flags to abstract poll event flags | 41 | /// Translate guest poll event flags to abstract poll event flags |
| 36 | PollEvents TranslatePollEventsToGuest(Network::PollEvents flags); | 42 | Network::PollEvents Translate(PollEvents flags); |
| 43 | |||
| 44 | /// Translate abstract poll event flags to guest poll event flags | ||
| 45 | PollEvents Translate(Network::PollEvents flags); | ||
| 37 | 46 | ||
| 38 | /// Translate guest socket address structure to abstract socket address structure | 47 | /// Translate guest socket address structure to abstract socket address structure |
| 39 | Network::SockAddrIn Translate(SockAddrIn value); | 48 | Network::SockAddrIn Translate(SockAddrIn value); |
diff --git a/src/core/hle/service/ssl/ssl.cpp b/src/core/hle/service/ssl/ssl.cpp index 2b99dd7ac..9c96f9763 100644 --- a/src/core/hle/service/ssl/ssl.cpp +++ b/src/core/hle/service/ssl/ssl.cpp | |||
| @@ -1,10 +1,18 @@ | |||
| 1 | // SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project | 1 | // SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project |
| 2 | // SPDX-License-Identifier: GPL-2.0-or-later | 2 | // SPDX-License-Identifier: GPL-2.0-or-later |
| 3 | 3 | ||
| 4 | #include "common/string_util.h" | ||
| 5 | |||
| 6 | #include "core/core.h" | ||
| 4 | #include "core/hle/service/ipc_helpers.h" | 7 | #include "core/hle/service/ipc_helpers.h" |
| 5 | #include "core/hle/service/server_manager.h" | 8 | #include "core/hle/service/server_manager.h" |
| 6 | #include "core/hle/service/service.h" | 9 | #include "core/hle/service/service.h" |
| 10 | #include "core/hle/service/sm/sm.h" | ||
| 11 | #include "core/hle/service/sockets/bsd.h" | ||
| 7 | #include "core/hle/service/ssl/ssl.h" | 12 | #include "core/hle/service/ssl/ssl.h" |
| 13 | #include "core/hle/service/ssl/ssl_backend.h" | ||
| 14 | #include "core/internal_network/network.h" | ||
| 15 | #include "core/internal_network/sockets.h" | ||
| 8 | 16 | ||
| 9 | namespace Service::SSL { | 17 | namespace Service::SSL { |
| 10 | 18 | ||
| @@ -20,6 +28,18 @@ enum class ContextOption : u32 { | |||
| 20 | CrlImportDateCheckEnable = 1, | 28 | CrlImportDateCheckEnable = 1, |
| 21 | }; | 29 | }; |
| 22 | 30 | ||
| 31 | // This is nn::ssl::Connection::IoMode | ||
| 32 | enum class IoMode : u32 { | ||
| 33 | Blocking = 1, | ||
| 34 | NonBlocking = 2, | ||
| 35 | }; | ||
| 36 | |||
| 37 | // This is nn::ssl::sf::OptionType | ||
| 38 | enum class OptionType : u32 { | ||
| 39 | DoNotCloseSocket = 0, | ||
| 40 | GetServerCertChain = 1, | ||
| 41 | }; | ||
| 42 | |||
| 23 | // This is nn::ssl::sf::SslVersion | 43 | // This is nn::ssl::sf::SslVersion |
| 24 | struct SslVersion { | 44 | struct SslVersion { |
| 25 | union { | 45 | union { |
| @@ -34,35 +54,42 @@ struct SslVersion { | |||
| 34 | }; | 54 | }; |
| 35 | }; | 55 | }; |
| 36 | 56 | ||
| 57 | struct SslContextSharedData { | ||
| 58 | u32 connection_count = 0; | ||
| 59 | }; | ||
| 60 | |||
| 37 | class ISslConnection final : public ServiceFramework<ISslConnection> { | 61 | class ISslConnection final : public ServiceFramework<ISslConnection> { |
| 38 | public: | 62 | public: |
| 39 | explicit ISslConnection(Core::System& system_, SslVersion version) | 63 | explicit ISslConnection(Core::System& system_in, SslVersion ssl_version_in, |
| 40 | : ServiceFramework{system_, "ISslConnection"}, ssl_version{version} { | 64 | std::shared_ptr<SslContextSharedData>& shared_data_in, |
| 65 | std::unique_ptr<SSLConnectionBackend>&& backend_in) | ||
| 66 | : ServiceFramework{system_in, "ISslConnection"}, ssl_version{ssl_version_in}, | ||
| 67 | shared_data{shared_data_in}, backend{std::move(backend_in)} { | ||
| 41 | // clang-format off | 68 | // clang-format off |
| 42 | static const FunctionInfo functions[] = { | 69 | static const FunctionInfo functions[] = { |
| 43 | {0, nullptr, "SetSocketDescriptor"}, | 70 | {0, &ISslConnection::SetSocketDescriptor, "SetSocketDescriptor"}, |
| 44 | {1, nullptr, "SetHostName"}, | 71 | {1, &ISslConnection::SetHostName, "SetHostName"}, |
| 45 | {2, nullptr, "SetVerifyOption"}, | 72 | {2, &ISslConnection::SetVerifyOption, "SetVerifyOption"}, |
| 46 | {3, nullptr, "SetIoMode"}, | 73 | {3, &ISslConnection::SetIoMode, "SetIoMode"}, |
| 47 | {4, nullptr, "GetSocketDescriptor"}, | 74 | {4, nullptr, "GetSocketDescriptor"}, |
| 48 | {5, nullptr, "GetHostName"}, | 75 | {5, nullptr, "GetHostName"}, |
| 49 | {6, nullptr, "GetVerifyOption"}, | 76 | {6, nullptr, "GetVerifyOption"}, |
| 50 | {7, nullptr, "GetIoMode"}, | 77 | {7, nullptr, "GetIoMode"}, |
| 51 | {8, nullptr, "DoHandshake"}, | 78 | {8, &ISslConnection::DoHandshake, "DoHandshake"}, |
| 52 | {9, nullptr, "DoHandshakeGetServerCert"}, | 79 | {9, &ISslConnection::DoHandshakeGetServerCert, "DoHandshakeGetServerCert"}, |
| 53 | {10, nullptr, "Read"}, | 80 | {10, &ISslConnection::Read, "Read"}, |
| 54 | {11, nullptr, "Write"}, | 81 | {11, &ISslConnection::Write, "Write"}, |
| 55 | {12, nullptr, "Pending"}, | 82 | {12, &ISslConnection::Pending, "Pending"}, |
| 56 | {13, nullptr, "Peek"}, | 83 | {13, nullptr, "Peek"}, |
| 57 | {14, nullptr, "Poll"}, | 84 | {14, nullptr, "Poll"}, |
| 58 | {15, nullptr, "GetVerifyCertError"}, | 85 | {15, nullptr, "GetVerifyCertError"}, |
| 59 | {16, nullptr, "GetNeededServerCertBufferSize"}, | 86 | {16, nullptr, "GetNeededServerCertBufferSize"}, |
| 60 | {17, nullptr, "SetSessionCacheMode"}, | 87 | {17, &ISslConnection::SetSessionCacheMode, "SetSessionCacheMode"}, |
| 61 | {18, nullptr, "GetSessionCacheMode"}, | 88 | {18, nullptr, "GetSessionCacheMode"}, |
| 62 | {19, nullptr, "FlushSessionCache"}, | 89 | {19, nullptr, "FlushSessionCache"}, |
| 63 | {20, nullptr, "SetRenegotiationMode"}, | 90 | {20, nullptr, "SetRenegotiationMode"}, |
| 64 | {21, nullptr, "GetRenegotiationMode"}, | 91 | {21, nullptr, "GetRenegotiationMode"}, |
| 65 | {22, nullptr, "SetOption"}, | 92 | {22, &ISslConnection::SetOption, "SetOption"}, |
| 66 | {23, nullptr, "GetOption"}, | 93 | {23, nullptr, "GetOption"}, |
| 67 | {24, nullptr, "GetVerifyCertErrors"}, | 94 | {24, nullptr, "GetVerifyCertErrors"}, |
| 68 | {25, nullptr, "GetCipherInfo"}, | 95 | {25, nullptr, "GetCipherInfo"}, |
| @@ -80,21 +107,299 @@ public: | |||
| 80 | // clang-format on | 107 | // clang-format on |
| 81 | 108 | ||
| 82 | RegisterHandlers(functions); | 109 | RegisterHandlers(functions); |
| 110 | |||
| 111 | shared_data->connection_count++; | ||
| 112 | } | ||
| 113 | |||
| 114 | ~ISslConnection() { | ||
| 115 | shared_data->connection_count--; | ||
| 116 | if (fd_to_close.has_value()) { | ||
| 117 | const s32 fd = *fd_to_close; | ||
| 118 | if (!do_not_close_socket) { | ||
| 119 | LOG_ERROR(Service_SSL, | ||
| 120 | "do_not_close_socket was changed after setting socket; is this right?"); | ||
| 121 | } else { | ||
| 122 | auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u"); | ||
| 123 | if (bsd) { | ||
| 124 | auto err = bsd->CloseImpl(fd); | ||
| 125 | if (err != Service::Sockets::Errno::SUCCESS) { | ||
| 126 | LOG_ERROR(Service_SSL, "Failed to close duplicated socket: {}", err); | ||
| 127 | } | ||
| 128 | } | ||
| 129 | } | ||
| 130 | } | ||
| 83 | } | 131 | } |
| 84 | 132 | ||
| 85 | private: | 133 | private: |
| 86 | SslVersion ssl_version; | 134 | SslVersion ssl_version; |
| 135 | std::shared_ptr<SslContextSharedData> shared_data; | ||
| 136 | std::unique_ptr<SSLConnectionBackend> backend; | ||
| 137 | std::optional<int> fd_to_close; | ||
| 138 | bool do_not_close_socket = false; | ||
| 139 | bool get_server_cert_chain = false; | ||
| 140 | std::shared_ptr<Network::SocketBase> socket; | ||
| 141 | bool did_set_host_name = false; | ||
| 142 | bool did_handshake = false; | ||
| 143 | |||
| 144 | ResultVal<s32> SetSocketDescriptorImpl(s32 fd) { | ||
| 145 | LOG_DEBUG(Service_SSL, "called, fd={}", fd); | ||
| 146 | ASSERT(!did_handshake); | ||
| 147 | auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u"); | ||
| 148 | ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; }); | ||
| 149 | s32 ret_fd; | ||
| 150 | // Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor | ||
| 151 | if (do_not_close_socket) { | ||
| 152 | auto res = bsd->DuplicateSocketImpl(fd); | ||
| 153 | if (!res.has_value()) { | ||
| 154 | LOG_ERROR(Service_SSL, "Failed to duplicate socket with fd {}", fd); | ||
| 155 | return ResultInvalidSocket; | ||
| 156 | } | ||
| 157 | fd = *res; | ||
| 158 | fd_to_close = fd; | ||
| 159 | ret_fd = fd; | ||
| 160 | } else { | ||
| 161 | ret_fd = -1; | ||
| 162 | } | ||
| 163 | std::optional<std::shared_ptr<Network::SocketBase>> sock = bsd->GetSocket(fd); | ||
| 164 | if (!sock.has_value()) { | ||
| 165 | LOG_ERROR(Service_SSL, "invalid socket fd {}", fd); | ||
| 166 | return ResultInvalidSocket; | ||
| 167 | } | ||
| 168 | socket = std::move(*sock); | ||
| 169 | backend->SetSocket(socket); | ||
| 170 | return ret_fd; | ||
| 171 | } | ||
| 172 | |||
| 173 | Result SetHostNameImpl(const std::string& hostname) { | ||
| 174 | LOG_DEBUG(Service_SSL, "called. hostname={}", hostname); | ||
| 175 | ASSERT(!did_handshake); | ||
| 176 | Result res = backend->SetHostName(hostname); | ||
| 177 | if (res == ResultSuccess) { | ||
| 178 | did_set_host_name = true; | ||
| 179 | } | ||
| 180 | return res; | ||
| 181 | } | ||
| 182 | |||
| 183 | Result SetVerifyOptionImpl(u32 option) { | ||
| 184 | ASSERT(!did_handshake); | ||
| 185 | LOG_WARNING(Service_SSL, "(STUBBED) called. option={}", option); | ||
| 186 | return ResultSuccess; | ||
| 187 | } | ||
| 188 | |||
| 189 | Result SetIoModeImpl(u32 input_mode) { | ||
| 190 | auto mode = static_cast<IoMode>(input_mode); | ||
| 191 | ASSERT(mode == IoMode::Blocking || mode == IoMode::NonBlocking); | ||
| 192 | ASSERT_OR_EXECUTE(socket, { return ResultNoSocket; }); | ||
| 193 | |||
| 194 | const bool non_block = mode == IoMode::NonBlocking; | ||
| 195 | const Network::Errno error = socket->SetNonBlock(non_block); | ||
| 196 | if (error != Network::Errno::SUCCESS) { | ||
| 197 | LOG_ERROR(Service_SSL, "Failed to set native socket non-block flag to {}", non_block); | ||
| 198 | } | ||
| 199 | return ResultSuccess; | ||
| 200 | } | ||
| 201 | |||
| 202 | Result SetSessionCacheModeImpl(u32 mode) { | ||
| 203 | ASSERT(!did_handshake); | ||
| 204 | LOG_WARNING(Service_SSL, "(STUBBED) called. value={}", mode); | ||
| 205 | return ResultSuccess; | ||
| 206 | } | ||
| 207 | |||
| 208 | Result DoHandshakeImpl() { | ||
| 209 | ASSERT_OR_EXECUTE(!did_handshake && socket, { return ResultNoSocket; }); | ||
| 210 | ASSERT_OR_EXECUTE_MSG( | ||
| 211 | did_set_host_name, { return ResultInternalError; }, | ||
| 212 | "Expected SetHostName before DoHandshake"); | ||
| 213 | Result res = backend->DoHandshake(); | ||
| 214 | did_handshake = res.IsSuccess(); | ||
| 215 | return res; | ||
| 216 | } | ||
| 217 | |||
| 218 | std::vector<u8> SerializeServerCerts(const std::vector<std::vector<u8>>& certs) { | ||
| 219 | struct Header { | ||
| 220 | u64 magic; | ||
| 221 | u32 count; | ||
| 222 | u32 pad; | ||
| 223 | }; | ||
| 224 | struct EntryHeader { | ||
| 225 | u32 size; | ||
| 226 | u32 offset; | ||
| 227 | }; | ||
| 228 | if (!get_server_cert_chain) { | ||
| 229 | // Just return the first one, unencoded. | ||
| 230 | ASSERT_OR_EXECUTE_MSG( | ||
| 231 | !certs.empty(), { return {}; }, "Should be at least one server cert"); | ||
| 232 | return certs[0]; | ||
| 233 | } | ||
| 234 | std::vector<u8> ret; | ||
| 235 | Header header{0x4E4D684374726543, static_cast<u32>(certs.size()), 0}; | ||
| 236 | ret.insert(ret.end(), reinterpret_cast<u8*>(&header), reinterpret_cast<u8*>(&header + 1)); | ||
| 237 | size_t data_offset = sizeof(Header) + certs.size() * sizeof(EntryHeader); | ||
| 238 | for (auto& cert : certs) { | ||
| 239 | EntryHeader entry_header{static_cast<u32>(cert.size()), static_cast<u32>(data_offset)}; | ||
| 240 | data_offset += cert.size(); | ||
| 241 | ret.insert(ret.end(), reinterpret_cast<u8*>(&entry_header), | ||
| 242 | reinterpret_cast<u8*>(&entry_header + 1)); | ||
| 243 | } | ||
| 244 | for (auto& cert : certs) { | ||
| 245 | ret.insert(ret.end(), cert.begin(), cert.end()); | ||
| 246 | } | ||
| 247 | return ret; | ||
| 248 | } | ||
| 249 | |||
| 250 | ResultVal<std::vector<u8>> ReadImpl(size_t size) { | ||
| 251 | ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; }); | ||
| 252 | std::vector<u8> res(size); | ||
| 253 | ResultVal<size_t> actual = backend->Read(res); | ||
| 254 | if (actual.Failed()) { | ||
| 255 | return actual.Code(); | ||
| 256 | } | ||
| 257 | res.resize(*actual); | ||
| 258 | return res; | ||
| 259 | } | ||
| 260 | |||
| 261 | ResultVal<size_t> WriteImpl(std::span<const u8> data) { | ||
| 262 | ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; }); | ||
| 263 | return backend->Write(data); | ||
| 264 | } | ||
| 265 | |||
| 266 | ResultVal<s32> PendingImpl() { | ||
| 267 | LOG_WARNING(Service_SSL, "(STUBBED) called."); | ||
| 268 | return 0; | ||
| 269 | } | ||
| 270 | |||
| 271 | void SetSocketDescriptor(HLERequestContext& ctx) { | ||
| 272 | IPC::RequestParser rp{ctx}; | ||
| 273 | const s32 fd = rp.Pop<s32>(); | ||
| 274 | const ResultVal<s32> res = SetSocketDescriptorImpl(fd); | ||
| 275 | IPC::ResponseBuilder rb{ctx, 3}; | ||
| 276 | rb.Push(res.Code()); | ||
| 277 | rb.Push<s32>(res.ValueOr(-1)); | ||
| 278 | } | ||
| 279 | |||
| 280 | void SetHostName(HLERequestContext& ctx) { | ||
| 281 | const std::string hostname = Common::StringFromBuffer(ctx.ReadBuffer()); | ||
| 282 | const Result res = SetHostNameImpl(hostname); | ||
| 283 | IPC::ResponseBuilder rb{ctx, 2}; | ||
| 284 | rb.Push(res); | ||
| 285 | } | ||
| 286 | |||
| 287 | void SetVerifyOption(HLERequestContext& ctx) { | ||
| 288 | IPC::RequestParser rp{ctx}; | ||
| 289 | const u32 option = rp.Pop<u32>(); | ||
| 290 | const Result res = SetVerifyOptionImpl(option); | ||
| 291 | IPC::ResponseBuilder rb{ctx, 2}; | ||
| 292 | rb.Push(res); | ||
| 293 | } | ||
| 294 | |||
| 295 | void SetIoMode(HLERequestContext& ctx) { | ||
| 296 | IPC::RequestParser rp{ctx}; | ||
| 297 | const u32 mode = rp.Pop<u32>(); | ||
| 298 | const Result res = SetIoModeImpl(mode); | ||
| 299 | IPC::ResponseBuilder rb{ctx, 2}; | ||
| 300 | rb.Push(res); | ||
| 301 | } | ||
| 302 | |||
| 303 | void DoHandshake(HLERequestContext& ctx) { | ||
| 304 | const Result res = DoHandshakeImpl(); | ||
| 305 | IPC::ResponseBuilder rb{ctx, 2}; | ||
| 306 | rb.Push(res); | ||
| 307 | } | ||
| 308 | |||
| 309 | void DoHandshakeGetServerCert(HLERequestContext& ctx) { | ||
| 310 | struct OutputParameters { | ||
| 311 | u32 certs_size; | ||
| 312 | u32 certs_count; | ||
| 313 | }; | ||
| 314 | static_assert(sizeof(OutputParameters) == 0x8); | ||
| 315 | |||
| 316 | const Result res = DoHandshakeImpl(); | ||
| 317 | OutputParameters out{}; | ||
| 318 | if (res == ResultSuccess) { | ||
| 319 | auto certs = backend->GetServerCerts(); | ||
| 320 | if (certs.Succeeded()) { | ||
| 321 | const std::vector<u8> certs_buf = SerializeServerCerts(*certs); | ||
| 322 | ctx.WriteBuffer(certs_buf); | ||
| 323 | out.certs_count = static_cast<u32>(certs->size()); | ||
| 324 | out.certs_size = static_cast<u32>(certs_buf.size()); | ||
| 325 | } | ||
| 326 | } | ||
| 327 | IPC::ResponseBuilder rb{ctx, 4}; | ||
| 328 | rb.Push(res); | ||
| 329 | rb.PushRaw(out); | ||
| 330 | } | ||
| 331 | |||
| 332 | void Read(HLERequestContext& ctx) { | ||
| 333 | const ResultVal<std::vector<u8>> res = ReadImpl(ctx.GetWriteBufferSize()); | ||
| 334 | IPC::ResponseBuilder rb{ctx, 3}; | ||
| 335 | rb.Push(res.Code()); | ||
| 336 | if (res.Succeeded()) { | ||
| 337 | rb.Push(static_cast<u32>(res->size())); | ||
| 338 | ctx.WriteBuffer(*res); | ||
| 339 | } else { | ||
| 340 | rb.Push(static_cast<u32>(0)); | ||
| 341 | } | ||
| 342 | } | ||
| 343 | |||
| 344 | void Write(HLERequestContext& ctx) { | ||
| 345 | const ResultVal<size_t> res = WriteImpl(ctx.ReadBuffer()); | ||
| 346 | IPC::ResponseBuilder rb{ctx, 3}; | ||
| 347 | rb.Push(res.Code()); | ||
| 348 | rb.Push(static_cast<u32>(res.ValueOr(0))); | ||
| 349 | } | ||
| 350 | |||
| 351 | void Pending(HLERequestContext& ctx) { | ||
| 352 | const ResultVal<s32> res = PendingImpl(); | ||
| 353 | IPC::ResponseBuilder rb{ctx, 3}; | ||
| 354 | rb.Push(res.Code()); | ||
| 355 | rb.Push<s32>(res.ValueOr(0)); | ||
| 356 | } | ||
| 357 | |||
| 358 | void SetSessionCacheMode(HLERequestContext& ctx) { | ||
| 359 | IPC::RequestParser rp{ctx}; | ||
| 360 | const u32 mode = rp.Pop<u32>(); | ||
| 361 | const Result res = SetSessionCacheModeImpl(mode); | ||
| 362 | IPC::ResponseBuilder rb{ctx, 2}; | ||
| 363 | rb.Push(res); | ||
| 364 | } | ||
| 365 | |||
| 366 | void SetOption(HLERequestContext& ctx) { | ||
| 367 | struct Parameters { | ||
| 368 | OptionType option; | ||
| 369 | s32 value; | ||
| 370 | }; | ||
| 371 | static_assert(sizeof(Parameters) == 0x8, "Parameters is an invalid size"); | ||
| 372 | |||
| 373 | IPC::RequestParser rp{ctx}; | ||
| 374 | const auto parameters = rp.PopRaw<Parameters>(); | ||
| 375 | |||
| 376 | switch (parameters.option) { | ||
| 377 | case OptionType::DoNotCloseSocket: | ||
| 378 | do_not_close_socket = static_cast<bool>(parameters.value); | ||
| 379 | break; | ||
| 380 | case OptionType::GetServerCertChain: | ||
| 381 | get_server_cert_chain = static_cast<bool>(parameters.value); | ||
| 382 | break; | ||
| 383 | default: | ||
| 384 | LOG_WARNING(Service_SSL, "Unknown option={}, value={}", parameters.option, | ||
| 385 | parameters.value); | ||
| 386 | } | ||
| 387 | |||
| 388 | IPC::ResponseBuilder rb{ctx, 2}; | ||
| 389 | rb.Push(ResultSuccess); | ||
| 390 | } | ||
| 87 | }; | 391 | }; |
| 88 | 392 | ||
| 89 | class ISslContext final : public ServiceFramework<ISslContext> { | 393 | class ISslContext final : public ServiceFramework<ISslContext> { |
| 90 | public: | 394 | public: |
| 91 | explicit ISslContext(Core::System& system_, SslVersion version) | 395 | explicit ISslContext(Core::System& system_, SslVersion version) |
| 92 | : ServiceFramework{system_, "ISslContext"}, ssl_version{version} { | 396 | : ServiceFramework{system_, "ISslContext"}, ssl_version{version}, |
| 397 | shared_data{std::make_shared<SslContextSharedData>()} { | ||
| 93 | static const FunctionInfo functions[] = { | 398 | static const FunctionInfo functions[] = { |
| 94 | {0, &ISslContext::SetOption, "SetOption"}, | 399 | {0, &ISslContext::SetOption, "SetOption"}, |
| 95 | {1, nullptr, "GetOption"}, | 400 | {1, nullptr, "GetOption"}, |
| 96 | {2, &ISslContext::CreateConnection, "CreateConnection"}, | 401 | {2, &ISslContext::CreateConnection, "CreateConnection"}, |
| 97 | {3, nullptr, "GetConnectionCount"}, | 402 | {3, &ISslContext::GetConnectionCount, "GetConnectionCount"}, |
| 98 | {4, &ISslContext::ImportServerPki, "ImportServerPki"}, | 403 | {4, &ISslContext::ImportServerPki, "ImportServerPki"}, |
| 99 | {5, &ISslContext::ImportClientPki, "ImportClientPki"}, | 404 | {5, &ISslContext::ImportClientPki, "ImportClientPki"}, |
| 100 | {6, nullptr, "RemoveServerPki"}, | 405 | {6, nullptr, "RemoveServerPki"}, |
| @@ -111,6 +416,7 @@ public: | |||
| 111 | 416 | ||
| 112 | private: | 417 | private: |
| 113 | SslVersion ssl_version; | 418 | SslVersion ssl_version; |
| 419 | std::shared_ptr<SslContextSharedData> shared_data; | ||
| 114 | 420 | ||
| 115 | void SetOption(HLERequestContext& ctx) { | 421 | void SetOption(HLERequestContext& ctx) { |
| 116 | struct Parameters { | 422 | struct Parameters { |
| @@ -130,11 +436,24 @@ private: | |||
| 130 | } | 436 | } |
| 131 | 437 | ||
| 132 | void CreateConnection(HLERequestContext& ctx) { | 438 | void CreateConnection(HLERequestContext& ctx) { |
| 133 | LOG_WARNING(Service_SSL, "(STUBBED) called"); | 439 | LOG_WARNING(Service_SSL, "called"); |
| 440 | |||
| 441 | auto backend_res = CreateSSLConnectionBackend(); | ||
| 134 | 442 | ||
| 135 | IPC::ResponseBuilder rb{ctx, 2, 0, 1}; | 443 | IPC::ResponseBuilder rb{ctx, 2, 0, 1}; |
| 444 | rb.Push(backend_res.Code()); | ||
| 445 | if (backend_res.Succeeded()) { | ||
| 446 | rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data, | ||
| 447 | std::move(*backend_res)); | ||
| 448 | } | ||
| 449 | } | ||
| 450 | |||
| 451 | void GetConnectionCount(HLERequestContext& ctx) { | ||
| 452 | LOG_DEBUG(Service_SSL, "connection_count={}", shared_data->connection_count); | ||
| 453 | |||
| 454 | IPC::ResponseBuilder rb{ctx, 3}; | ||
| 136 | rb.Push(ResultSuccess); | 455 | rb.Push(ResultSuccess); |
| 137 | rb.PushIpcInterface<ISslConnection>(system, ssl_version); | 456 | rb.Push(shared_data->connection_count); |
| 138 | } | 457 | } |
| 139 | 458 | ||
| 140 | void ImportServerPki(HLERequestContext& ctx) { | 459 | void ImportServerPki(HLERequestContext& ctx) { |
diff --git a/src/core/hle/service/ssl/ssl_backend.h b/src/core/hle/service/ssl/ssl_backend.h new file mode 100644 index 000000000..25c16bcc1 --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend.h | |||
| @@ -0,0 +1,45 @@ | |||
| 1 | // SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project | ||
| 2 | // SPDX-License-Identifier: GPL-2.0-or-later | ||
| 3 | |||
| 4 | #pragma once | ||
| 5 | |||
| 6 | #include "core/hle/result.h" | ||
| 7 | |||
| 8 | #include "common/common_types.h" | ||
| 9 | |||
| 10 | #include <memory> | ||
| 11 | #include <span> | ||
| 12 | #include <string> | ||
| 13 | #include <vector> | ||
| 14 | |||
| 15 | namespace Network { | ||
| 16 | class SocketBase; | ||
| 17 | } | ||
| 18 | |||
| 19 | namespace Service::SSL { | ||
| 20 | |||
| 21 | constexpr Result ResultNoSocket{ErrorModule::SSLSrv, 103}; | ||
| 22 | constexpr Result ResultInvalidSocket{ErrorModule::SSLSrv, 106}; | ||
| 23 | constexpr Result ResultTimeout{ErrorModule::SSLSrv, 205}; | ||
| 24 | constexpr Result ResultInternalError{ErrorModule::SSLSrv, 999}; // made up | ||
| 25 | |||
| 26 | // ResultWouldBlock is returned from Read and Write, and oddly, DoHandshake, | ||
| 27 | // with no way in the latter case to distinguish whether the client should poll | ||
| 28 | // for read or write. The one official client I've seen handles this by always | ||
| 29 | // polling for read (with a timeout). | ||
| 30 | constexpr Result ResultWouldBlock{ErrorModule::SSLSrv, 204}; | ||
| 31 | |||
| 32 | class SSLConnectionBackend { | ||
| 33 | public: | ||
| 34 | virtual ~SSLConnectionBackend() {} | ||
| 35 | virtual void SetSocket(std::shared_ptr<Network::SocketBase> socket) = 0; | ||
| 36 | virtual Result SetHostName(const std::string& hostname) = 0; | ||
| 37 | virtual Result DoHandshake() = 0; | ||
| 38 | virtual ResultVal<size_t> Read(std::span<u8> data) = 0; | ||
| 39 | virtual ResultVal<size_t> Write(std::span<const u8> data) = 0; | ||
| 40 | virtual ResultVal<std::vector<std::vector<u8>>> GetServerCerts() = 0; | ||
| 41 | }; | ||
| 42 | |||
| 43 | ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend(); | ||
| 44 | |||
| 45 | } // namespace Service::SSL | ||
diff --git a/src/core/hle/service/ssl/ssl_backend_none.cpp b/src/core/hle/service/ssl/ssl_backend_none.cpp new file mode 100644 index 000000000..f2f0ef706 --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_none.cpp | |||
| @@ -0,0 +1,16 @@ | |||
| 1 | // SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project | ||
| 2 | // SPDX-License-Identifier: GPL-2.0-or-later | ||
| 3 | |||
| 4 | #include "core/hle/service/ssl/ssl_backend.h" | ||
| 5 | |||
| 6 | #include "common/logging/log.h" | ||
| 7 | |||
| 8 | namespace Service::SSL { | ||
| 9 | |||
| 10 | ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { | ||
| 11 | LOG_ERROR(Service_SSL, | ||
| 12 | "Can't create SSL connection because no SSL backend is available on this platform"); | ||
| 13 | return ResultInternalError; | ||
| 14 | } | ||
| 15 | |||
| 16 | } // namespace Service::SSL | ||
diff --git a/src/core/hle/service/ssl/ssl_backend_openssl.cpp b/src/core/hle/service/ssl/ssl_backend_openssl.cpp new file mode 100644 index 000000000..f69674f77 --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_openssl.cpp | |||
| @@ -0,0 +1,351 @@ | |||
| 1 | // SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project | ||
| 2 | // SPDX-License-Identifier: GPL-2.0-or-later | ||
| 3 | |||
| 4 | #include "core/hle/service/ssl/ssl_backend.h" | ||
| 5 | #include "core/internal_network/network.h" | ||
| 6 | #include "core/internal_network/sockets.h" | ||
| 7 | |||
| 8 | #include "common/fs/file.h" | ||
| 9 | #include "common/hex_util.h" | ||
| 10 | #include "common/string_util.h" | ||
| 11 | |||
| 12 | #include <mutex> | ||
| 13 | |||
| 14 | #include <openssl/bio.h> | ||
| 15 | #include <openssl/err.h> | ||
| 16 | #include <openssl/ssl.h> | ||
| 17 | #include <openssl/x509.h> | ||
| 18 | |||
| 19 | using namespace Common::FS; | ||
| 20 | |||
| 21 | namespace Service::SSL { | ||
| 22 | |||
| 23 | // Import OpenSSL's `SSL` type into the namespace. This is needed because the | ||
| 24 | // namespace is also named `SSL`. | ||
| 25 | using ::SSL; | ||
| 26 | |||
| 27 | namespace { | ||
| 28 | |||
| 29 | std::once_flag one_time_init_flag; | ||
| 30 | bool one_time_init_success = false; | ||
| 31 | |||
| 32 | SSL_CTX* ssl_ctx; | ||
| 33 | IOFile key_log_file; // only open if SSLKEYLOGFILE set in environment | ||
| 34 | BIO_METHOD* bio_meth; | ||
| 35 | |||
| 36 | Result CheckOpenSSLErrors(); | ||
| 37 | void OneTimeInit(); | ||
| 38 | void OneTimeInitLogFile(); | ||
| 39 | bool OneTimeInitBIO(); | ||
| 40 | |||
| 41 | } // namespace | ||
| 42 | |||
| 43 | class SSLConnectionBackendOpenSSL final : public SSLConnectionBackend { | ||
| 44 | public: | ||
| 45 | Result Init() { | ||
| 46 | std::call_once(one_time_init_flag, OneTimeInit); | ||
| 47 | |||
| 48 | if (!one_time_init_success) { | ||
| 49 | LOG_ERROR(Service_SSL, | ||
| 50 | "Can't create SSL connection because OpenSSL one-time initialization failed"); | ||
| 51 | return ResultInternalError; | ||
| 52 | } | ||
| 53 | |||
| 54 | ssl = SSL_new(ssl_ctx); | ||
| 55 | if (!ssl) { | ||
| 56 | LOG_ERROR(Service_SSL, "SSL_new failed"); | ||
| 57 | return CheckOpenSSLErrors(); | ||
| 58 | } | ||
| 59 | |||
| 60 | SSL_set_connect_state(ssl); | ||
| 61 | |||
| 62 | bio = BIO_new(bio_meth); | ||
| 63 | if (!bio) { | ||
| 64 | LOG_ERROR(Service_SSL, "BIO_new failed"); | ||
| 65 | return CheckOpenSSLErrors(); | ||
| 66 | } | ||
| 67 | |||
| 68 | BIO_set_data(bio, this); | ||
| 69 | BIO_set_init(bio, 1); | ||
| 70 | SSL_set_bio(ssl, bio, bio); | ||
| 71 | |||
| 72 | return ResultSuccess; | ||
| 73 | } | ||
| 74 | |||
| 75 | void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override { | ||
| 76 | socket = std::move(socket_in); | ||
| 77 | } | ||
| 78 | |||
| 79 | Result SetHostName(const std::string& hostname) override { | ||
| 80 | if (!SSL_set1_host(ssl, hostname.c_str())) { // hostname for verification | ||
| 81 | LOG_ERROR(Service_SSL, "SSL_set1_host({}) failed", hostname); | ||
| 82 | return CheckOpenSSLErrors(); | ||
| 83 | } | ||
| 84 | if (!SSL_set_tlsext_host_name(ssl, hostname.c_str())) { // hostname for SNI | ||
| 85 | LOG_ERROR(Service_SSL, "SSL_set_tlsext_host_name({}) failed", hostname); | ||
| 86 | return CheckOpenSSLErrors(); | ||
| 87 | } | ||
| 88 | return ResultSuccess; | ||
| 89 | } | ||
| 90 | |||
| 91 | Result DoHandshake() override { | ||
| 92 | SSL_set_verify_result(ssl, X509_V_OK); | ||
| 93 | const int ret = SSL_do_handshake(ssl); | ||
| 94 | const long verify_result = SSL_get_verify_result(ssl); | ||
| 95 | if (verify_result != X509_V_OK) { | ||
| 96 | LOG_ERROR(Service_SSL, "SSL cert verification failed because: {}", | ||
| 97 | X509_verify_cert_error_string(verify_result)); | ||
| 98 | return CheckOpenSSLErrors(); | ||
| 99 | } | ||
| 100 | if (ret <= 0) { | ||
| 101 | const int ssl_err = SSL_get_error(ssl, ret); | ||
| 102 | if (ssl_err == SSL_ERROR_ZERO_RETURN || | ||
| 103 | (ssl_err == SSL_ERROR_SYSCALL && got_read_eof)) { | ||
| 104 | LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up"); | ||
| 105 | return ResultInternalError; | ||
| 106 | } | ||
| 107 | } | ||
| 108 | return HandleReturn("SSL_do_handshake", 0, ret).Code(); | ||
| 109 | } | ||
| 110 | |||
| 111 | ResultVal<size_t> Read(std::span<u8> data) override { | ||
| 112 | size_t actual; | ||
| 113 | const int ret = SSL_read_ex(ssl, data.data(), data.size(), &actual); | ||
| 114 | return HandleReturn("SSL_read_ex", actual, ret); | ||
| 115 | } | ||
| 116 | |||
| 117 | ResultVal<size_t> Write(std::span<const u8> data) override { | ||
| 118 | size_t actual; | ||
| 119 | const int ret = SSL_write_ex(ssl, data.data(), data.size(), &actual); | ||
| 120 | return HandleReturn("SSL_write_ex", actual, ret); | ||
| 121 | } | ||
| 122 | |||
| 123 | ResultVal<size_t> HandleReturn(const char* what, size_t actual, int ret) { | ||
| 124 | const int ssl_err = SSL_get_error(ssl, ret); | ||
| 125 | CheckOpenSSLErrors(); | ||
| 126 | switch (ssl_err) { | ||
| 127 | case SSL_ERROR_NONE: | ||
| 128 | return actual; | ||
| 129 | case SSL_ERROR_ZERO_RETURN: | ||
| 130 | LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_ZERO_RETURN", what); | ||
| 131 | // DoHandshake special-cases this, but for Read and Write: | ||
| 132 | return size_t(0); | ||
| 133 | case SSL_ERROR_WANT_READ: | ||
| 134 | LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_READ", what); | ||
| 135 | return ResultWouldBlock; | ||
| 136 | case SSL_ERROR_WANT_WRITE: | ||
| 137 | LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_WRITE", what); | ||
| 138 | return ResultWouldBlock; | ||
| 139 | default: | ||
| 140 | if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof) { | ||
| 141 | LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what); | ||
| 142 | return size_t(0); | ||
| 143 | } | ||
| 144 | LOG_ERROR(Service_SSL, "{} => other SSL_get_error return value {}", what, ssl_err); | ||
| 145 | return ResultInternalError; | ||
| 146 | } | ||
| 147 | } | ||
| 148 | |||
| 149 | ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { | ||
| 150 | STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl); | ||
| 151 | if (!chain) { | ||
| 152 | LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr"); | ||
| 153 | return ResultInternalError; | ||
| 154 | } | ||
| 155 | std::vector<std::vector<u8>> ret; | ||
| 156 | int count = sk_X509_num(chain); | ||
| 157 | ASSERT(count >= 0); | ||
| 158 | for (int i = 0; i < count; i++) { | ||
| 159 | X509* x509 = sk_X509_value(chain, i); | ||
| 160 | ASSERT_OR_EXECUTE(x509 != nullptr, { continue; }); | ||
| 161 | unsigned char* buf = nullptr; | ||
| 162 | int len = i2d_X509(x509, &buf); | ||
| 163 | ASSERT_OR_EXECUTE(len >= 0 && buf, { continue; }); | ||
| 164 | ret.emplace_back(buf, buf + len); | ||
| 165 | OPENSSL_free(buf); | ||
| 166 | } | ||
| 167 | return ret; | ||
| 168 | } | ||
| 169 | |||
| 170 | ~SSLConnectionBackendOpenSSL() { | ||
| 171 | // these are null-tolerant: | ||
| 172 | SSL_free(ssl); | ||
| 173 | BIO_free(bio); | ||
| 174 | } | ||
| 175 | |||
| 176 | static void KeyLogCallback(const SSL* ssl, const char* line) { | ||
| 177 | std::string str(line); | ||
| 178 | str.push_back('\n'); | ||
| 179 | // Do this in a single WriteString for atomicity if multiple instances | ||
| 180 | // are running on different threads (though that can't currently | ||
| 181 | // happen). | ||
| 182 | if (key_log_file.WriteString(str) != str.size() || !key_log_file.Flush()) { | ||
| 183 | LOG_CRITICAL(Service_SSL, "Failed to write to SSLKEYLOGFILE"); | ||
| 184 | } | ||
| 185 | LOG_DEBUG(Service_SSL, "Wrote to SSLKEYLOGFILE: {}", line); | ||
| 186 | } | ||
| 187 | |||
| 188 | static int WriteCallback(BIO* bio, const char* buf, size_t len, size_t* actual_p) { | ||
| 189 | auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio)); | ||
| 190 | ASSERT_OR_EXECUTE_MSG( | ||
| 191 | self->socket, { return 0; }, "OpenSSL asked to send but we have no socket"); | ||
| 192 | BIO_clear_retry_flags(bio); | ||
| 193 | auto [actual, err] = self->socket->Send({reinterpret_cast<const u8*>(buf), len}, 0); | ||
| 194 | switch (err) { | ||
| 195 | case Network::Errno::SUCCESS: | ||
| 196 | *actual_p = actual; | ||
| 197 | return 1; | ||
| 198 | case Network::Errno::AGAIN: | ||
| 199 | BIO_set_flags(bio, BIO_FLAGS_WRITE | BIO_FLAGS_SHOULD_RETRY); | ||
| 200 | return 0; | ||
| 201 | default: | ||
| 202 | LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err); | ||
| 203 | return -1; | ||
| 204 | } | ||
| 205 | } | ||
| 206 | |||
| 207 | static int ReadCallback(BIO* bio, char* buf, size_t len, size_t* actual_p) { | ||
| 208 | auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio)); | ||
| 209 | ASSERT_OR_EXECUTE_MSG( | ||
| 210 | self->socket, { return 0; }, "OpenSSL asked to recv but we have no socket"); | ||
| 211 | BIO_clear_retry_flags(bio); | ||
| 212 | auto [actual, err] = self->socket->Recv(0, {reinterpret_cast<u8*>(buf), len}); | ||
| 213 | switch (err) { | ||
| 214 | case Network::Errno::SUCCESS: | ||
| 215 | *actual_p = actual; | ||
| 216 | if (actual == 0) { | ||
| 217 | self->got_read_eof = true; | ||
| 218 | } | ||
| 219 | return actual ? 1 : 0; | ||
| 220 | case Network::Errno::AGAIN: | ||
| 221 | BIO_set_flags(bio, BIO_FLAGS_READ | BIO_FLAGS_SHOULD_RETRY); | ||
| 222 | return 0; | ||
| 223 | default: | ||
| 224 | LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err); | ||
| 225 | return -1; | ||
| 226 | } | ||
| 227 | } | ||
| 228 | |||
| 229 | static long CtrlCallback(BIO* bio, int cmd, long l_arg, void* p_arg) { | ||
| 230 | switch (cmd) { | ||
| 231 | case BIO_CTRL_FLUSH: | ||
| 232 | // Nothing to flush. | ||
| 233 | return 1; | ||
| 234 | case BIO_CTRL_PUSH: | ||
| 235 | case BIO_CTRL_POP: | ||
| 236 | #ifdef BIO_CTRL_GET_KTLS_SEND | ||
| 237 | case BIO_CTRL_GET_KTLS_SEND: | ||
| 238 | case BIO_CTRL_GET_KTLS_RECV: | ||
| 239 | #endif | ||
| 240 | // We don't support these operations, but don't bother logging them | ||
| 241 | // as they're nothing unusual. | ||
| 242 | return 0; | ||
| 243 | default: | ||
| 244 | LOG_DEBUG(Service_SSL, "OpenSSL BIO got ctrl({}, {}, {})", cmd, l_arg, p_arg); | ||
| 245 | return 0; | ||
| 246 | } | ||
| 247 | } | ||
| 248 | |||
| 249 | SSL* ssl = nullptr; | ||
| 250 | BIO* bio = nullptr; | ||
| 251 | bool got_read_eof = false; | ||
| 252 | |||
| 253 | std::shared_ptr<Network::SocketBase> socket; | ||
| 254 | }; | ||
| 255 | |||
| 256 | ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { | ||
| 257 | auto conn = std::make_unique<SSLConnectionBackendOpenSSL>(); | ||
| 258 | const Result res = conn->Init(); | ||
| 259 | if (res.IsFailure()) { | ||
| 260 | return res; | ||
| 261 | } | ||
| 262 | return conn; | ||
| 263 | } | ||
| 264 | |||
| 265 | namespace { | ||
| 266 | |||
| 267 | Result CheckOpenSSLErrors() { | ||
| 268 | unsigned long rc; | ||
| 269 | const char* file; | ||
| 270 | int line; | ||
| 271 | const char* func; | ||
| 272 | const char* data; | ||
| 273 | int flags; | ||
| 274 | #if OPENSSL_VERSION_NUMBER >= 0x30000000L | ||
| 275 | while ((rc = ERR_get_error_all(&file, &line, &func, &data, &flags))) | ||
| 276 | #else | ||
| 277 | // Can't get function names from OpenSSL on this version, so use mine: | ||
| 278 | func = __func__; | ||
| 279 | while ((rc = ERR_get_error_line_data(&file, &line, &data, &flags))) | ||
| 280 | #endif | ||
| 281 | { | ||
| 282 | std::string msg; | ||
| 283 | msg.resize(1024, '\0'); | ||
| 284 | ERR_error_string_n(rc, msg.data(), msg.size()); | ||
| 285 | msg.resize(strlen(msg.data()), '\0'); | ||
| 286 | if (flags & ERR_TXT_STRING) { | ||
| 287 | msg.append(" | "); | ||
| 288 | msg.append(data); | ||
| 289 | } | ||
| 290 | Common::Log::FmtLogMessage(Common::Log::Class::Service_SSL, Common::Log::Level::Error, | ||
| 291 | Common::Log::TrimSourcePath(file), line, func, "OpenSSL: {}", | ||
| 292 | msg); | ||
| 293 | } | ||
| 294 | return ResultInternalError; | ||
| 295 | } | ||
| 296 | |||
| 297 | void OneTimeInit() { | ||
| 298 | ssl_ctx = SSL_CTX_new(TLS_client_method()); | ||
| 299 | if (!ssl_ctx) { | ||
| 300 | LOG_ERROR(Service_SSL, "SSL_CTX_new failed"); | ||
| 301 | CheckOpenSSLErrors(); | ||
| 302 | return; | ||
| 303 | } | ||
| 304 | |||
| 305 | SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, nullptr); | ||
| 306 | |||
| 307 | if (!SSL_CTX_set_default_verify_paths(ssl_ctx)) { | ||
| 308 | LOG_ERROR(Service_SSL, "SSL_CTX_set_default_verify_paths failed"); | ||
| 309 | CheckOpenSSLErrors(); | ||
| 310 | return; | ||
| 311 | } | ||
| 312 | |||
| 313 | OneTimeInitLogFile(); | ||
| 314 | |||
| 315 | if (!OneTimeInitBIO()) { | ||
| 316 | return; | ||
| 317 | } | ||
| 318 | |||
| 319 | one_time_init_success = true; | ||
| 320 | } | ||
| 321 | |||
| 322 | void OneTimeInitLogFile() { | ||
| 323 | const char* logfile = getenv("SSLKEYLOGFILE"); | ||
| 324 | if (logfile) { | ||
| 325 | key_log_file.Open(logfile, FileAccessMode::Append, FileType::TextFile, | ||
| 326 | FileShareFlag::ShareWriteOnly); | ||
| 327 | if (key_log_file.IsOpen()) { | ||
| 328 | SSL_CTX_set_keylog_callback(ssl_ctx, &SSLConnectionBackendOpenSSL::KeyLogCallback); | ||
| 329 | } else { | ||
| 330 | LOG_CRITICAL(Service_SSL, | ||
| 331 | "SSLKEYLOGFILE was set but file could not be opened; not logging keys!"); | ||
| 332 | } | ||
| 333 | } | ||
| 334 | } | ||
| 335 | |||
| 336 | bool OneTimeInitBIO() { | ||
| 337 | bio_meth = | ||
| 338 | BIO_meth_new(BIO_get_new_index() | BIO_TYPE_SOURCE_SINK, "SSLConnectionBackendOpenSSL"); | ||
| 339 | if (!bio_meth || | ||
| 340 | !BIO_meth_set_write_ex(bio_meth, &SSLConnectionBackendOpenSSL::WriteCallback) || | ||
| 341 | !BIO_meth_set_read_ex(bio_meth, &SSLConnectionBackendOpenSSL::ReadCallback) || | ||
| 342 | !BIO_meth_set_ctrl(bio_meth, &SSLConnectionBackendOpenSSL::CtrlCallback)) { | ||
| 343 | LOG_ERROR(Service_SSL, "Failed to create BIO_METHOD"); | ||
| 344 | return false; | ||
| 345 | } | ||
| 346 | return true; | ||
| 347 | } | ||
| 348 | |||
| 349 | } // namespace | ||
| 350 | |||
| 351 | } // namespace Service::SSL | ||
diff --git a/src/core/hle/service/ssl/ssl_backend_schannel.cpp b/src/core/hle/service/ssl/ssl_backend_schannel.cpp new file mode 100644 index 000000000..a1d6a186e --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_schannel.cpp | |||
| @@ -0,0 +1,543 @@ | |||
| 1 | // SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project | ||
| 2 | // SPDX-License-Identifier: GPL-2.0-or-later | ||
| 3 | |||
| 4 | #include "core/hle/service/ssl/ssl_backend.h" | ||
| 5 | #include "core/internal_network/network.h" | ||
| 6 | #include "core/internal_network/sockets.h" | ||
| 7 | |||
| 8 | #include "common/error.h" | ||
| 9 | #include "common/fs/file.h" | ||
| 10 | #include "common/hex_util.h" | ||
| 11 | #include "common/string_util.h" | ||
| 12 | |||
| 13 | #include <mutex> | ||
| 14 | |||
| 15 | namespace { | ||
| 16 | |||
| 17 | // These includes are inside the namespace to avoid a conflict on MinGW where | ||
| 18 | // the headers define an enum containing Network and Service as enumerators | ||
| 19 | // (which clash with the correspondingly named namespaces). | ||
| 20 | #define SECURITY_WIN32 | ||
| 21 | #include <schnlsp.h> | ||
| 22 | #include <security.h> | ||
| 23 | |||
| 24 | std::once_flag one_time_init_flag; | ||
| 25 | bool one_time_init_success = false; | ||
| 26 | |||
| 27 | SCHANNEL_CRED schannel_cred{}; | ||
| 28 | CredHandle cred_handle; | ||
| 29 | |||
| 30 | static void OneTimeInit() { | ||
| 31 | schannel_cred.dwVersion = SCHANNEL_CRED_VERSION; | ||
| 32 | schannel_cred.dwFlags = | ||
| 33 | SCH_USE_STRONG_CRYPTO | // don't allow insecure protocols | ||
| 34 | SCH_CRED_AUTO_CRED_VALIDATION | // validate certs | ||
| 35 | SCH_CRED_NO_DEFAULT_CREDS; // don't automatically present a client certificate | ||
| 36 | // ^ I'm assuming that nobody would want to connect Yuzu to a | ||
| 37 | // service that requires some OS-provided corporate client | ||
| 38 | // certificate, and presenting one to some arbitrary server | ||
| 39 | // might be a privacy concern? Who knows, though. | ||
| 40 | |||
| 41 | const SECURITY_STATUS ret = | ||
| 42 | AcquireCredentialsHandle(nullptr, const_cast<LPTSTR>(UNISP_NAME), SECPKG_CRED_OUTBOUND, | ||
| 43 | nullptr, &schannel_cred, nullptr, nullptr, &cred_handle, nullptr); | ||
| 44 | if (ret != SEC_E_OK) { | ||
| 45 | // SECURITY_STATUS codes are a type of HRESULT and can be used with NativeErrorToString. | ||
| 46 | LOG_ERROR(Service_SSL, "AcquireCredentialsHandle failed: {}", | ||
| 47 | Common::NativeErrorToString(ret)); | ||
| 48 | return; | ||
| 49 | } | ||
| 50 | |||
| 51 | if (getenv("SSLKEYLOGFILE")) { | ||
| 52 | LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but Schannel does not support exporting " | ||
| 53 | "keys; not logging keys!"); | ||
| 54 | // Not fatal. | ||
| 55 | } | ||
| 56 | |||
| 57 | one_time_init_success = true; | ||
| 58 | } | ||
| 59 | |||
| 60 | } // namespace | ||
| 61 | |||
| 62 | namespace Service::SSL { | ||
| 63 | |||
| 64 | class SSLConnectionBackendSchannel final : public SSLConnectionBackend { | ||
| 65 | public: | ||
| 66 | Result Init() { | ||
| 67 | std::call_once(one_time_init_flag, OneTimeInit); | ||
| 68 | |||
| 69 | if (!one_time_init_success) { | ||
| 70 | LOG_ERROR( | ||
| 71 | Service_SSL, | ||
| 72 | "Can't create SSL connection because Schannel one-time initialization failed"); | ||
| 73 | return ResultInternalError; | ||
| 74 | } | ||
| 75 | |||
| 76 | return ResultSuccess; | ||
| 77 | } | ||
| 78 | |||
| 79 | void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override { | ||
| 80 | socket = std::move(socket_in); | ||
| 81 | } | ||
| 82 | |||
| 83 | Result SetHostName(const std::string& hostname_in) override { | ||
| 84 | hostname = hostname_in; | ||
| 85 | return ResultSuccess; | ||
| 86 | } | ||
| 87 | |||
| 88 | Result DoHandshake() override { | ||
| 89 | while (1) { | ||
| 90 | Result r; | ||
| 91 | switch (handshake_state) { | ||
| 92 | case HandshakeState::Initial: | ||
| 93 | if ((r = FlushCiphertextWriteBuf()) != ResultSuccess || | ||
| 94 | (r = CallInitializeSecurityContext()) != ResultSuccess) { | ||
| 95 | return r; | ||
| 96 | } | ||
| 97 | // CallInitializeSecurityContext updated `handshake_state`. | ||
| 98 | continue; | ||
| 99 | case HandshakeState::ContinueNeeded: | ||
| 100 | case HandshakeState::IncompleteMessage: | ||
| 101 | if ((r = FlushCiphertextWriteBuf()) != ResultSuccess || | ||
| 102 | (r = FillCiphertextReadBuf()) != ResultSuccess) { | ||
| 103 | return r; | ||
| 104 | } | ||
| 105 | if (ciphertext_read_buf.empty()) { | ||
| 106 | LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up"); | ||
| 107 | return ResultInternalError; | ||
| 108 | } | ||
| 109 | if ((r = CallInitializeSecurityContext()) != ResultSuccess) { | ||
| 110 | return r; | ||
| 111 | } | ||
| 112 | // CallInitializeSecurityContext updated `handshake_state`. | ||
| 113 | continue; | ||
| 114 | case HandshakeState::DoneAfterFlush: | ||
| 115 | if ((r = FlushCiphertextWriteBuf()) != ResultSuccess) { | ||
| 116 | return r; | ||
| 117 | } | ||
| 118 | handshake_state = HandshakeState::Connected; | ||
| 119 | return ResultSuccess; | ||
| 120 | case HandshakeState::Connected: | ||
| 121 | LOG_ERROR(Service_SSL, "Called DoHandshake but we already handshook"); | ||
| 122 | return ResultInternalError; | ||
| 123 | case HandshakeState::Error: | ||
| 124 | return ResultInternalError; | ||
| 125 | } | ||
| 126 | } | ||
| 127 | } | ||
| 128 | |||
| 129 | Result FillCiphertextReadBuf() { | ||
| 130 | const size_t fill_size = read_buf_fill_size ? read_buf_fill_size : 4096; | ||
| 131 | read_buf_fill_size = 0; | ||
| 132 | // This unnecessarily zeroes the buffer; oh well. | ||
| 133 | const size_t offset = ciphertext_read_buf.size(); | ||
| 134 | ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; }); | ||
| 135 | ciphertext_read_buf.resize(offset + fill_size, 0); | ||
| 136 | const auto read_span = std::span(ciphertext_read_buf).subspan(offset, fill_size); | ||
| 137 | const auto [actual, err] = socket->Recv(0, read_span); | ||
| 138 | switch (err) { | ||
| 139 | case Network::Errno::SUCCESS: | ||
| 140 | ASSERT(static_cast<size_t>(actual) <= fill_size); | ||
| 141 | ciphertext_read_buf.resize(offset + actual); | ||
| 142 | return ResultSuccess; | ||
| 143 | case Network::Errno::AGAIN: | ||
| 144 | ciphertext_read_buf.resize(offset); | ||
| 145 | return ResultWouldBlock; | ||
| 146 | default: | ||
| 147 | ciphertext_read_buf.resize(offset); | ||
| 148 | LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err); | ||
| 149 | return ResultInternalError; | ||
| 150 | } | ||
| 151 | } | ||
| 152 | |||
| 153 | // Returns success if the write buffer has been completely emptied. | ||
| 154 | Result FlushCiphertextWriteBuf() { | ||
| 155 | while (!ciphertext_write_buf.empty()) { | ||
| 156 | const auto [actual, err] = socket->Send(ciphertext_write_buf, 0); | ||
| 157 | switch (err) { | ||
| 158 | case Network::Errno::SUCCESS: | ||
| 159 | ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf.size()); | ||
| 160 | ciphertext_write_buf.erase(ciphertext_write_buf.begin(), | ||
| 161 | ciphertext_write_buf.begin() + actual); | ||
| 162 | break; | ||
| 163 | case Network::Errno::AGAIN: | ||
| 164 | return ResultWouldBlock; | ||
| 165 | default: | ||
| 166 | LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err); | ||
| 167 | return ResultInternalError; | ||
| 168 | } | ||
| 169 | } | ||
| 170 | return ResultSuccess; | ||
| 171 | } | ||
| 172 | |||
| 173 | Result CallInitializeSecurityContext() { | ||
| 174 | const unsigned long req = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | | ||
| 175 | ISC_REQ_INTEGRITY | ISC_REQ_REPLAY_DETECT | | ||
| 176 | ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM | | ||
| 177 | ISC_REQ_USE_SUPPLIED_CREDS; | ||
| 178 | unsigned long attr; | ||
| 179 | // https://learn.microsoft.com/en-us/windows/win32/secauthn/initializesecuritycontext--schannel | ||
| 180 | std::array<SecBuffer, 2> input_buffers{{ | ||
| 181 | // only used if `initial_call_done` | ||
| 182 | { | ||
| 183 | // [0] | ||
| 184 | .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()), | ||
| 185 | .BufferType = SECBUFFER_TOKEN, | ||
| 186 | .pvBuffer = ciphertext_read_buf.data(), | ||
| 187 | }, | ||
| 188 | { | ||
| 189 | // [1] (will be replaced by SECBUFFER_MISSING when SEC_E_INCOMPLETE_MESSAGE is | ||
| 190 | // returned, or SECBUFFER_EXTRA when SEC_E_CONTINUE_NEEDED is returned if the | ||
| 191 | // whole buffer wasn't used) | ||
| 192 | .cbBuffer = 0, | ||
| 193 | .BufferType = SECBUFFER_EMPTY, | ||
| 194 | .pvBuffer = nullptr, | ||
| 195 | }, | ||
| 196 | }}; | ||
| 197 | std::array<SecBuffer, 2> output_buffers{{ | ||
| 198 | { | ||
| 199 | .cbBuffer = 0, | ||
| 200 | .BufferType = SECBUFFER_TOKEN, | ||
| 201 | .pvBuffer = nullptr, | ||
| 202 | }, // [0] | ||
| 203 | { | ||
| 204 | .cbBuffer = 0, | ||
| 205 | .BufferType = SECBUFFER_ALERT, | ||
| 206 | .pvBuffer = nullptr, | ||
| 207 | }, // [1] | ||
| 208 | }}; | ||
| 209 | SecBufferDesc input_desc{ | ||
| 210 | .ulVersion = SECBUFFER_VERSION, | ||
| 211 | .cBuffers = static_cast<unsigned long>(input_buffers.size()), | ||
| 212 | .pBuffers = input_buffers.data(), | ||
| 213 | }; | ||
| 214 | SecBufferDesc output_desc{ | ||
| 215 | .ulVersion = SECBUFFER_VERSION, | ||
| 216 | .cBuffers = static_cast<unsigned long>(output_buffers.size()), | ||
| 217 | .pBuffers = output_buffers.data(), | ||
| 218 | }; | ||
| 219 | ASSERT_OR_EXECUTE_MSG( | ||
| 220 | input_buffers[0].cbBuffer == ciphertext_read_buf.size(), | ||
| 221 | { return ResultInternalError; }, "read buffer too large"); | ||
| 222 | |||
| 223 | bool initial_call_done = handshake_state != HandshakeState::Initial; | ||
| 224 | if (initial_call_done) { | ||
| 225 | LOG_DEBUG(Service_SSL, "Passing {} bytes into InitializeSecurityContext", | ||
| 226 | ciphertext_read_buf.size()); | ||
| 227 | } | ||
| 228 | |||
| 229 | const SECURITY_STATUS ret = | ||
| 230 | InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt : nullptr, | ||
| 231 | // Caller ensured we have set a hostname: | ||
| 232 | const_cast<char*>(hostname.value().c_str()), req, | ||
| 233 | 0, // Reserved1 | ||
| 234 | 0, // TargetDataRep not used with Schannel | ||
| 235 | initial_call_done ? &input_desc : nullptr, | ||
| 236 | 0, // Reserved2 | ||
| 237 | initial_call_done ? nullptr : &ctxt, &output_desc, &attr, | ||
| 238 | nullptr); // ptsExpiry | ||
| 239 | |||
| 240 | if (output_buffers[0].pvBuffer) { | ||
| 241 | const std::span span(static_cast<u8*>(output_buffers[0].pvBuffer), | ||
| 242 | output_buffers[0].cbBuffer); | ||
| 243 | ciphertext_write_buf.insert(ciphertext_write_buf.end(), span.begin(), span.end()); | ||
| 244 | FreeContextBuffer(output_buffers[0].pvBuffer); | ||
| 245 | } | ||
| 246 | |||
| 247 | if (output_buffers[1].pvBuffer) { | ||
| 248 | const std::span span(static_cast<u8*>(output_buffers[1].pvBuffer), | ||
| 249 | output_buffers[1].cbBuffer); | ||
| 250 | // The documentation doesn't explain what format this data is in. | ||
| 251 | LOG_DEBUG(Service_SSL, "Got a {}-byte alert buffer: {}", span.size(), | ||
| 252 | Common::HexToString(span)); | ||
| 253 | } | ||
| 254 | |||
| 255 | switch (ret) { | ||
| 256 | case SEC_I_CONTINUE_NEEDED: | ||
| 257 | LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_I_CONTINUE_NEEDED"); | ||
| 258 | if (input_buffers[1].BufferType == SECBUFFER_EXTRA) { | ||
| 259 | LOG_DEBUG(Service_SSL, "EXTRA of size {}", input_buffers[1].cbBuffer); | ||
| 260 | ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf.size()); | ||
| 261 | ciphertext_read_buf.erase(ciphertext_read_buf.begin(), | ||
| 262 | ciphertext_read_buf.end() - input_buffers[1].cbBuffer); | ||
| 263 | } else { | ||
| 264 | ASSERT(input_buffers[1].BufferType == SECBUFFER_EMPTY); | ||
| 265 | ciphertext_read_buf.clear(); | ||
| 266 | } | ||
| 267 | handshake_state = HandshakeState::ContinueNeeded; | ||
| 268 | return ResultSuccess; | ||
| 269 | case SEC_E_INCOMPLETE_MESSAGE: | ||
| 270 | LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_INCOMPLETE_MESSAGE"); | ||
| 271 | ASSERT(input_buffers[1].BufferType == SECBUFFER_MISSING); | ||
| 272 | read_buf_fill_size = input_buffers[1].cbBuffer; | ||
| 273 | handshake_state = HandshakeState::IncompleteMessage; | ||
| 274 | return ResultSuccess; | ||
| 275 | case SEC_E_OK: | ||
| 276 | LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_OK"); | ||
| 277 | ciphertext_read_buf.clear(); | ||
| 278 | handshake_state = HandshakeState::DoneAfterFlush; | ||
| 279 | return GrabStreamSizes(); | ||
| 280 | default: | ||
| 281 | LOG_ERROR(Service_SSL, | ||
| 282 | "InitializeSecurityContext failed (probably certificate/protocol issue): {}", | ||
| 283 | Common::NativeErrorToString(ret)); | ||
| 284 | handshake_state = HandshakeState::Error; | ||
| 285 | return ResultInternalError; | ||
| 286 | } | ||
| 287 | } | ||
| 288 | |||
| 289 | Result GrabStreamSizes() { | ||
| 290 | const SECURITY_STATUS ret = | ||
| 291 | QueryContextAttributes(&ctxt, SECPKG_ATTR_STREAM_SIZES, &stream_sizes); | ||
| 292 | if (ret != SEC_E_OK) { | ||
| 293 | LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}", | ||
| 294 | Common::NativeErrorToString(ret)); | ||
| 295 | handshake_state = HandshakeState::Error; | ||
| 296 | return ResultInternalError; | ||
| 297 | } | ||
| 298 | return ResultSuccess; | ||
| 299 | } | ||
| 300 | |||
| 301 | ResultVal<size_t> Read(std::span<u8> data) override { | ||
| 302 | if (handshake_state != HandshakeState::Connected) { | ||
| 303 | LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake"); | ||
| 304 | return ResultInternalError; | ||
| 305 | } | ||
| 306 | if (data.size() == 0 || got_read_eof) { | ||
| 307 | return size_t(0); | ||
| 308 | } | ||
| 309 | while (1) { | ||
| 310 | if (!cleartext_read_buf.empty()) { | ||
| 311 | const size_t read_size = std::min(cleartext_read_buf.size(), data.size()); | ||
| 312 | std::memcpy(data.data(), cleartext_read_buf.data(), read_size); | ||
| 313 | cleartext_read_buf.erase(cleartext_read_buf.begin(), | ||
| 314 | cleartext_read_buf.begin() + read_size); | ||
| 315 | return read_size; | ||
| 316 | } | ||
| 317 | if (!ciphertext_read_buf.empty()) { | ||
| 318 | SecBuffer empty{ | ||
| 319 | .cbBuffer = 0, | ||
| 320 | .BufferType = SECBUFFER_EMPTY, | ||
| 321 | .pvBuffer = nullptr, | ||
| 322 | }; | ||
| 323 | std::array<SecBuffer, 5> buffers{{ | ||
| 324 | { | ||
| 325 | .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()), | ||
| 326 | .BufferType = SECBUFFER_DATA, | ||
| 327 | .pvBuffer = ciphertext_read_buf.data(), | ||
| 328 | }, | ||
| 329 | empty, | ||
| 330 | empty, | ||
| 331 | empty, | ||
| 332 | }}; | ||
| 333 | ASSERT_OR_EXECUTE_MSG( | ||
| 334 | buffers[0].cbBuffer == ciphertext_read_buf.size(), | ||
| 335 | { return ResultInternalError; }, "read buffer too large"); | ||
| 336 | SecBufferDesc desc{ | ||
| 337 | .ulVersion = SECBUFFER_VERSION, | ||
| 338 | .cBuffers = static_cast<unsigned long>(buffers.size()), | ||
| 339 | .pBuffers = buffers.data(), | ||
| 340 | }; | ||
| 341 | SECURITY_STATUS ret = | ||
| 342 | DecryptMessage(&ctxt, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr); | ||
| 343 | switch (ret) { | ||
| 344 | case SEC_E_OK: | ||
| 345 | ASSERT_OR_EXECUTE(buffers[0].BufferType == SECBUFFER_STREAM_HEADER, | ||
| 346 | { return ResultInternalError; }); | ||
| 347 | ASSERT_OR_EXECUTE(buffers[1].BufferType == SECBUFFER_DATA, | ||
| 348 | { return ResultInternalError; }); | ||
| 349 | ASSERT_OR_EXECUTE(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER, | ||
| 350 | { return ResultInternalError; }); | ||
| 351 | cleartext_read_buf.assign(static_cast<u8*>(buffers[1].pvBuffer), | ||
| 352 | static_cast<u8*>(buffers[1].pvBuffer) + | ||
| 353 | buffers[1].cbBuffer); | ||
| 354 | if (buffers[3].BufferType == SECBUFFER_EXTRA) { | ||
| 355 | ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf.size()); | ||
| 356 | ciphertext_read_buf.erase(ciphertext_read_buf.begin(), | ||
| 357 | ciphertext_read_buf.end() - buffers[3].cbBuffer); | ||
| 358 | } else { | ||
| 359 | ASSERT(buffers[3].BufferType == SECBUFFER_EMPTY); | ||
| 360 | ciphertext_read_buf.clear(); | ||
| 361 | } | ||
| 362 | continue; | ||
| 363 | case SEC_E_INCOMPLETE_MESSAGE: | ||
| 364 | break; | ||
| 365 | case SEC_I_CONTEXT_EXPIRED: | ||
| 366 | // Server hung up by sending close_notify. | ||
| 367 | got_read_eof = true; | ||
| 368 | return size_t(0); | ||
| 369 | default: | ||
| 370 | LOG_ERROR(Service_SSL, "DecryptMessage failed: {}", | ||
| 371 | Common::NativeErrorToString(ret)); | ||
| 372 | return ResultInternalError; | ||
| 373 | } | ||
| 374 | } | ||
| 375 | const Result r = FillCiphertextReadBuf(); | ||
| 376 | if (r != ResultSuccess) { | ||
| 377 | return r; | ||
| 378 | } | ||
| 379 | if (ciphertext_read_buf.empty()) { | ||
| 380 | got_read_eof = true; | ||
| 381 | return size_t(0); | ||
| 382 | } | ||
| 383 | } | ||
| 384 | } | ||
| 385 | |||
| 386 | ResultVal<size_t> Write(std::span<const u8> data) override { | ||
| 387 | if (handshake_state != HandshakeState::Connected) { | ||
| 388 | LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake"); | ||
| 389 | return ResultInternalError; | ||
| 390 | } | ||
| 391 | if (data.size() == 0) { | ||
| 392 | return size_t(0); | ||
| 393 | } | ||
| 394 | data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes.cbMaximumMessage)); | ||
| 395 | if (!cleartext_write_buf.empty()) { | ||
| 396 | // Already in the middle of a write. It wouldn't make sense to not | ||
| 397 | // finish sending the entire buffer since TLS has | ||
| 398 | // header/MAC/padding/etc. | ||
| 399 | if (data.size() != cleartext_write_buf.size() || | ||
| 400 | std::memcmp(data.data(), cleartext_write_buf.data(), data.size())) { | ||
| 401 | LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer"); | ||
| 402 | return ResultInternalError; | ||
| 403 | } | ||
| 404 | return WriteAlreadyEncryptedData(); | ||
| 405 | } else { | ||
| 406 | cleartext_write_buf.assign(data.begin(), data.end()); | ||
| 407 | } | ||
| 408 | |||
| 409 | std::vector<u8> header_buf(stream_sizes.cbHeader, 0); | ||
| 410 | std::vector<u8> tmp_data_buf = cleartext_write_buf; | ||
| 411 | std::vector<u8> trailer_buf(stream_sizes.cbTrailer, 0); | ||
| 412 | |||
| 413 | std::array<SecBuffer, 3> buffers{{ | ||
| 414 | { | ||
| 415 | .cbBuffer = stream_sizes.cbHeader, | ||
| 416 | .BufferType = SECBUFFER_STREAM_HEADER, | ||
| 417 | .pvBuffer = header_buf.data(), | ||
| 418 | }, | ||
| 419 | { | ||
| 420 | .cbBuffer = static_cast<unsigned long>(tmp_data_buf.size()), | ||
| 421 | .BufferType = SECBUFFER_DATA, | ||
| 422 | .pvBuffer = tmp_data_buf.data(), | ||
| 423 | }, | ||
| 424 | { | ||
| 425 | .cbBuffer = stream_sizes.cbTrailer, | ||
| 426 | .BufferType = SECBUFFER_STREAM_TRAILER, | ||
| 427 | .pvBuffer = trailer_buf.data(), | ||
| 428 | }, | ||
| 429 | }}; | ||
| 430 | ASSERT_OR_EXECUTE_MSG( | ||
| 431 | buffers[1].cbBuffer == tmp_data_buf.size(), { return ResultInternalError; }, | ||
| 432 | "temp buffer too large"); | ||
| 433 | SecBufferDesc desc{ | ||
| 434 | .ulVersion = SECBUFFER_VERSION, | ||
| 435 | .cBuffers = static_cast<unsigned long>(buffers.size()), | ||
| 436 | .pBuffers = buffers.data(), | ||
| 437 | }; | ||
| 438 | |||
| 439 | const SECURITY_STATUS ret = EncryptMessage(&ctxt, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0); | ||
| 440 | if (ret != SEC_E_OK) { | ||
| 441 | LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret)); | ||
| 442 | return ResultInternalError; | ||
| 443 | } | ||
| 444 | ciphertext_write_buf.insert(ciphertext_write_buf.end(), header_buf.begin(), | ||
| 445 | header_buf.end()); | ||
| 446 | ciphertext_write_buf.insert(ciphertext_write_buf.end(), tmp_data_buf.begin(), | ||
| 447 | tmp_data_buf.end()); | ||
| 448 | ciphertext_write_buf.insert(ciphertext_write_buf.end(), trailer_buf.begin(), | ||
| 449 | trailer_buf.end()); | ||
| 450 | return WriteAlreadyEncryptedData(); | ||
| 451 | } | ||
| 452 | |||
| 453 | ResultVal<size_t> WriteAlreadyEncryptedData() { | ||
| 454 | const Result r = FlushCiphertextWriteBuf(); | ||
| 455 | if (r != ResultSuccess) { | ||
| 456 | return r; | ||
| 457 | } | ||
| 458 | // write buf is empty | ||
| 459 | const size_t cleartext_bytes_written = cleartext_write_buf.size(); | ||
| 460 | cleartext_write_buf.clear(); | ||
| 461 | return cleartext_bytes_written; | ||
| 462 | } | ||
| 463 | |||
| 464 | ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { | ||
| 465 | PCCERT_CONTEXT returned_cert = nullptr; | ||
| 466 | const SECURITY_STATUS ret = | ||
| 467 | QueryContextAttributes(&ctxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert); | ||
| 468 | if (ret != SEC_E_OK) { | ||
| 469 | LOG_ERROR(Service_SSL, | ||
| 470 | "QueryContextAttributes(SECPKG_ATTR_REMOTE_CERT_CONTEXT) failed: {}", | ||
| 471 | Common::NativeErrorToString(ret)); | ||
| 472 | return ResultInternalError; | ||
| 473 | } | ||
| 474 | PCCERT_CONTEXT some_cert = nullptr; | ||
| 475 | std::vector<std::vector<u8>> certs; | ||
| 476 | while ((some_cert = CertEnumCertificatesInStore(returned_cert->hCertStore, some_cert))) { | ||
| 477 | certs.emplace_back(static_cast<u8*>(some_cert->pbCertEncoded), | ||
| 478 | static_cast<u8*>(some_cert->pbCertEncoded) + | ||
| 479 | some_cert->cbCertEncoded); | ||
| 480 | } | ||
| 481 | std::reverse(certs.begin(), | ||
| 482 | certs.end()); // Windows returns certs in reverse order from what we want | ||
| 483 | CertFreeCertificateContext(returned_cert); | ||
| 484 | return certs; | ||
| 485 | } | ||
| 486 | |||
| 487 | ~SSLConnectionBackendSchannel() { | ||
| 488 | if (handshake_state != HandshakeState::Initial) { | ||
| 489 | DeleteSecurityContext(&ctxt); | ||
| 490 | } | ||
| 491 | } | ||
| 492 | |||
| 493 | enum class HandshakeState { | ||
| 494 | // Haven't called anything yet. | ||
| 495 | Initial, | ||
| 496 | // `SEC_I_CONTINUE_NEEDED` was returned by | ||
| 497 | // `InitializeSecurityContext`; must finish sending data (if any) in | ||
| 498 | // the write buffer, then read at least one byte before calling | ||
| 499 | // `InitializeSecurityContext` again. | ||
| 500 | ContinueNeeded, | ||
| 501 | // `SEC_E_INCOMPLETE_MESSAGE` was returned by | ||
| 502 | // `InitializeSecurityContext`; hopefully the write buffer is empty; | ||
| 503 | // must read at least one byte before calling | ||
| 504 | // `InitializeSecurityContext` again. | ||
| 505 | IncompleteMessage, | ||
| 506 | // `SEC_E_OK` was returned by `InitializeSecurityContext`; must | ||
| 507 | // finish sending data in the write buffer before having `DoHandshake` | ||
| 508 | // report success. | ||
| 509 | DoneAfterFlush, | ||
| 510 | // We finished the above and are now connected. At this point, writing | ||
| 511 | // and reading are separate 'state machines' represented by the | ||
| 512 | // nonemptiness of the ciphertext and cleartext read and write buffers. | ||
| 513 | Connected, | ||
| 514 | // Another error was returned and we shouldn't allow initialization | ||
| 515 | // to continue. | ||
| 516 | Error, | ||
| 517 | } handshake_state = HandshakeState::Initial; | ||
| 518 | |||
| 519 | CtxtHandle ctxt; | ||
| 520 | SecPkgContext_StreamSizes stream_sizes; | ||
| 521 | |||
| 522 | std::shared_ptr<Network::SocketBase> socket; | ||
| 523 | std::optional<std::string> hostname; | ||
| 524 | |||
| 525 | std::vector<u8> ciphertext_read_buf; | ||
| 526 | std::vector<u8> ciphertext_write_buf; | ||
| 527 | std::vector<u8> cleartext_read_buf; | ||
| 528 | std::vector<u8> cleartext_write_buf; | ||
| 529 | |||
| 530 | bool got_read_eof = false; | ||
| 531 | size_t read_buf_fill_size = 0; | ||
| 532 | }; | ||
| 533 | |||
| 534 | ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { | ||
| 535 | auto conn = std::make_unique<SSLConnectionBackendSchannel>(); | ||
| 536 | const Result res = conn->Init(); | ||
| 537 | if (res.IsFailure()) { | ||
| 538 | return res; | ||
| 539 | } | ||
| 540 | return conn; | ||
| 541 | } | ||
| 542 | |||
| 543 | } // namespace Service::SSL | ||
diff --git a/src/core/hle/service/ssl/ssl_backend_securetransport.cpp b/src/core/hle/service/ssl/ssl_backend_securetransport.cpp new file mode 100644 index 000000000..be40a5aeb --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_securetransport.cpp | |||
| @@ -0,0 +1,219 @@ | |||
| 1 | // SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project | ||
| 2 | // SPDX-License-Identifier: GPL-2.0-or-later | ||
| 3 | |||
| 4 | #include "core/hle/service/ssl/ssl_backend.h" | ||
| 5 | #include "core/internal_network/network.h" | ||
| 6 | #include "core/internal_network/sockets.h" | ||
| 7 | |||
| 8 | #include <mutex> | ||
| 9 | |||
| 10 | #include <Security/SecureTransport.h> | ||
| 11 | |||
| 12 | // SecureTransport has been deprecated in its entirety in favor of | ||
| 13 | // Network.framework, but that does not allow layering TLS on top of an | ||
| 14 | // arbitrary socket. | ||
| 15 | #pragma GCC diagnostic ignored "-Wdeprecated-declarations" | ||
| 16 | |||
| 17 | namespace { | ||
| 18 | |||
| 19 | template <typename T> | ||
| 20 | struct CFReleaser { | ||
| 21 | T ptr; | ||
| 22 | |||
| 23 | YUZU_NON_COPYABLE(CFReleaser); | ||
| 24 | constexpr CFReleaser() : ptr(nullptr) {} | ||
| 25 | constexpr CFReleaser(T ptr) : ptr(ptr) {} | ||
| 26 | constexpr operator T() { | ||
| 27 | return ptr; | ||
| 28 | } | ||
| 29 | ~CFReleaser() { | ||
| 30 | if (ptr) { | ||
| 31 | CFRelease(ptr); | ||
| 32 | } | ||
| 33 | } | ||
| 34 | }; | ||
| 35 | |||
| 36 | std::string CFStringToString(CFStringRef cfstr) { | ||
| 37 | CFReleaser<CFDataRef> cfdata( | ||
| 38 | CFStringCreateExternalRepresentation(nullptr, cfstr, kCFStringEncodingUTF8, 0)); | ||
| 39 | ASSERT_OR_EXECUTE(cfdata, { return "???"; }); | ||
| 40 | return std::string(reinterpret_cast<const char*>(CFDataGetBytePtr(cfdata)), | ||
| 41 | CFDataGetLength(cfdata)); | ||
| 42 | } | ||
| 43 | |||
| 44 | std::string OSStatusToString(OSStatus status) { | ||
| 45 | CFReleaser<CFStringRef> cfstr(SecCopyErrorMessageString(status, nullptr)); | ||
| 46 | if (!cfstr) { | ||
| 47 | return "[unknown error]"; | ||
| 48 | } | ||
| 49 | return CFStringToString(cfstr); | ||
| 50 | } | ||
| 51 | |||
| 52 | } // namespace | ||
| 53 | |||
| 54 | namespace Service::SSL { | ||
| 55 | |||
| 56 | class SSLConnectionBackendSecureTransport final : public SSLConnectionBackend { | ||
| 57 | public: | ||
| 58 | Result Init() { | ||
| 59 | static std::once_flag once_flag; | ||
| 60 | std::call_once(once_flag, []() { | ||
| 61 | if (getenv("SSLKEYLOGFILE")) { | ||
| 62 | LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but SecureTransport does not " | ||
| 63 | "support exporting keys; not logging keys!"); | ||
| 64 | // Not fatal. | ||
| 65 | } | ||
| 66 | }); | ||
| 67 | |||
| 68 | context.ptr = SSLCreateContext(nullptr, kSSLClientSide, kSSLStreamType); | ||
| 69 | if (!context) { | ||
| 70 | LOG_ERROR(Service_SSL, "SSLCreateContext failed"); | ||
| 71 | return ResultInternalError; | ||
| 72 | } | ||
| 73 | |||
| 74 | OSStatus status; | ||
| 75 | if ((status = SSLSetIOFuncs(context, ReadCallback, WriteCallback)) || | ||
| 76 | (status = SSLSetConnection(context, this))) { | ||
| 77 | LOG_ERROR(Service_SSL, "SSLContext initialization failed: {}", | ||
| 78 | OSStatusToString(status)); | ||
| 79 | return ResultInternalError; | ||
| 80 | } | ||
| 81 | |||
| 82 | return ResultSuccess; | ||
| 83 | } | ||
| 84 | |||
| 85 | void SetSocket(std::shared_ptr<Network::SocketBase> in_socket) override { | ||
| 86 | socket = std::move(in_socket); | ||
| 87 | } | ||
| 88 | |||
| 89 | Result SetHostName(const std::string& hostname) override { | ||
| 90 | OSStatus status = SSLSetPeerDomainName(context, hostname.c_str(), hostname.size()); | ||
| 91 | if (status) { | ||
| 92 | LOG_ERROR(Service_SSL, "SSLSetPeerDomainName failed: {}", OSStatusToString(status)); | ||
| 93 | return ResultInternalError; | ||
| 94 | } | ||
| 95 | return ResultSuccess; | ||
| 96 | } | ||
| 97 | |||
| 98 | Result DoHandshake() override { | ||
| 99 | OSStatus status = SSLHandshake(context); | ||
| 100 | return HandleReturn("SSLHandshake", 0, status).Code(); | ||
| 101 | } | ||
| 102 | |||
| 103 | ResultVal<size_t> Read(std::span<u8> data) override { | ||
| 104 | size_t actual; | ||
| 105 | OSStatus status = SSLRead(context, data.data(), data.size(), &actual); | ||
| 106 | ; | ||
| 107 | return HandleReturn("SSLRead", actual, status); | ||
| 108 | } | ||
| 109 | |||
| 110 | ResultVal<size_t> Write(std::span<const u8> data) override { | ||
| 111 | size_t actual; | ||
| 112 | OSStatus status = SSLWrite(context, data.data(), data.size(), &actual); | ||
| 113 | ; | ||
| 114 | return HandleReturn("SSLWrite", actual, status); | ||
| 115 | } | ||
| 116 | |||
| 117 | ResultVal<size_t> HandleReturn(const char* what, size_t actual, OSStatus status) { | ||
| 118 | switch (status) { | ||
| 119 | case 0: | ||
| 120 | return actual; | ||
| 121 | case errSSLWouldBlock: | ||
| 122 | return ResultWouldBlock; | ||
| 123 | default: { | ||
| 124 | std::string reason; | ||
| 125 | if (got_read_eof) { | ||
| 126 | reason = "server hung up"; | ||
| 127 | } else { | ||
| 128 | reason = OSStatusToString(status); | ||
| 129 | } | ||
| 130 | LOG_ERROR(Service_SSL, "{} failed: {}", what, reason); | ||
| 131 | return ResultInternalError; | ||
| 132 | } | ||
| 133 | } | ||
| 134 | } | ||
| 135 | |||
| 136 | ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { | ||
| 137 | CFReleaser<SecTrustRef> trust; | ||
| 138 | OSStatus status = SSLCopyPeerTrust(context, &trust.ptr); | ||
| 139 | if (status) { | ||
| 140 | LOG_ERROR(Service_SSL, "SSLCopyPeerTrust failed: {}", OSStatusToString(status)); | ||
| 141 | return ResultInternalError; | ||
| 142 | } | ||
| 143 | std::vector<std::vector<u8>> ret; | ||
| 144 | for (CFIndex i = 0, count = SecTrustGetCertificateCount(trust); i < count; i++) { | ||
| 145 | SecCertificateRef cert = SecTrustGetCertificateAtIndex(trust, i); | ||
| 146 | CFReleaser<CFDataRef> data(SecCertificateCopyData(cert)); | ||
| 147 | ASSERT_OR_EXECUTE(data, { return ResultInternalError; }); | ||
| 148 | const u8* ptr = CFDataGetBytePtr(data); | ||
| 149 | ret.emplace_back(ptr, ptr + CFDataGetLength(data)); | ||
| 150 | } | ||
| 151 | return ret; | ||
| 152 | } | ||
| 153 | |||
| 154 | static OSStatus ReadCallback(SSLConnectionRef connection, void* data, size_t* dataLength) { | ||
| 155 | return ReadOrWriteCallback(connection, data, dataLength, true); | ||
| 156 | } | ||
| 157 | |||
| 158 | static OSStatus WriteCallback(SSLConnectionRef connection, const void* data, | ||
| 159 | size_t* dataLength) { | ||
| 160 | return ReadOrWriteCallback(connection, const_cast<void*>(data), dataLength, false); | ||
| 161 | } | ||
| 162 | |||
| 163 | static OSStatus ReadOrWriteCallback(SSLConnectionRef connection, void* data, size_t* dataLength, | ||
| 164 | bool is_read) { | ||
| 165 | auto self = | ||
| 166 | static_cast<SSLConnectionBackendSecureTransport*>(const_cast<void*>(connection)); | ||
| 167 | ASSERT_OR_EXECUTE_MSG( | ||
| 168 | self->socket, { return 0; }, "SecureTransport asked to {} but we have no socket", | ||
| 169 | is_read ? "read" : "write"); | ||
| 170 | |||
| 171 | // SecureTransport callbacks (unlike OpenSSL BIO callbacks) are | ||
| 172 | // expected to read/write the full requested dataLength or return an | ||
| 173 | // error, so we have to add a loop ourselves. | ||
| 174 | size_t requested_len = *dataLength; | ||
| 175 | size_t offset = 0; | ||
| 176 | while (offset < requested_len) { | ||
| 177 | std::span cur(reinterpret_cast<u8*>(data) + offset, requested_len - offset); | ||
| 178 | auto [actual, err] = is_read ? self->socket->Recv(0, cur) : self->socket->Send(cur, 0); | ||
| 179 | LOG_CRITICAL(Service_SSL, "op={}, offset={} actual={}/{} err={}", is_read, offset, | ||
| 180 | actual, cur.size(), static_cast<s32>(err)); | ||
| 181 | switch (err) { | ||
| 182 | case Network::Errno::SUCCESS: | ||
| 183 | offset += actual; | ||
| 184 | if (actual == 0) { | ||
| 185 | ASSERT(is_read); | ||
| 186 | self->got_read_eof = true; | ||
| 187 | return errSecEndOfData; | ||
| 188 | } | ||
| 189 | break; | ||
| 190 | case Network::Errno::AGAIN: | ||
| 191 | *dataLength = offset; | ||
| 192 | return errSSLWouldBlock; | ||
| 193 | default: | ||
| 194 | LOG_ERROR(Service_SSL, "Socket {} returned Network::Errno {}", | ||
| 195 | is_read ? "recv" : "send", err); | ||
| 196 | return errSecIO; | ||
| 197 | } | ||
| 198 | } | ||
| 199 | ASSERT(offset == requested_len); | ||
| 200 | return 0; | ||
| 201 | } | ||
| 202 | |||
| 203 | private: | ||
| 204 | CFReleaser<SSLContextRef> context = nullptr; | ||
| 205 | bool got_read_eof = false; | ||
| 206 | |||
| 207 | std::shared_ptr<Network::SocketBase> socket; | ||
| 208 | }; | ||
| 209 | |||
| 210 | ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { | ||
| 211 | auto conn = std::make_unique<SSLConnectionBackendSecureTransport>(); | ||
| 212 | const Result res = conn->Init(); | ||
| 213 | if (res.IsFailure()) { | ||
| 214 | return res; | ||
| 215 | } | ||
| 216 | return conn; | ||
| 217 | } | ||
| 218 | |||
| 219 | } // namespace Service::SSL | ||
diff --git a/src/core/internal_network/network.cpp b/src/core/internal_network/network.cpp index 75ac10a9c..28f89c599 100644 --- a/src/core/internal_network/network.cpp +++ b/src/core/internal_network/network.cpp | |||
| @@ -27,6 +27,7 @@ | |||
| 27 | 27 | ||
| 28 | #include "common/assert.h" | 28 | #include "common/assert.h" |
| 29 | #include "common/common_types.h" | 29 | #include "common/common_types.h" |
| 30 | #include "common/expected.h" | ||
| 30 | #include "common/logging/log.h" | 31 | #include "common/logging/log.h" |
| 31 | #include "common/settings.h" | 32 | #include "common/settings.h" |
| 32 | #include "core/internal_network/network.h" | 33 | #include "core/internal_network/network.h" |
| @@ -97,6 +98,8 @@ bool EnableNonBlock(SOCKET fd, bool enable) { | |||
| 97 | 98 | ||
| 98 | Errno TranslateNativeError(int e) { | 99 | Errno TranslateNativeError(int e) { |
| 99 | switch (e) { | 100 | switch (e) { |
| 101 | case 0: | ||
| 102 | return Errno::SUCCESS; | ||
| 100 | case WSAEBADF: | 103 | case WSAEBADF: |
| 101 | return Errno::BADF; | 104 | return Errno::BADF; |
| 102 | case WSAEINVAL: | 105 | case WSAEINVAL: |
| @@ -121,6 +124,8 @@ Errno TranslateNativeError(int e) { | |||
| 121 | return Errno::MSGSIZE; | 124 | return Errno::MSGSIZE; |
| 122 | case WSAETIMEDOUT: | 125 | case WSAETIMEDOUT: |
| 123 | return Errno::TIMEDOUT; | 126 | return Errno::TIMEDOUT; |
| 127 | case WSAEINPROGRESS: | ||
| 128 | return Errno::INPROGRESS; | ||
| 124 | default: | 129 | default: |
| 125 | UNIMPLEMENTED_MSG("Unimplemented errno={}", e); | 130 | UNIMPLEMENTED_MSG("Unimplemented errno={}", e); |
| 126 | return Errno::OTHER; | 131 | return Errno::OTHER; |
| @@ -195,6 +200,8 @@ bool EnableNonBlock(int fd, bool enable) { | |||
| 195 | 200 | ||
| 196 | Errno TranslateNativeError(int e) { | 201 | Errno TranslateNativeError(int e) { |
| 197 | switch (e) { | 202 | switch (e) { |
| 203 | case 0: | ||
| 204 | return Errno::SUCCESS; | ||
| 198 | case EBADF: | 205 | case EBADF: |
| 199 | return Errno::BADF; | 206 | return Errno::BADF; |
| 200 | case EINVAL: | 207 | case EINVAL: |
| @@ -219,8 +226,10 @@ Errno TranslateNativeError(int e) { | |||
| 219 | return Errno::MSGSIZE; | 226 | return Errno::MSGSIZE; |
| 220 | case ETIMEDOUT: | 227 | case ETIMEDOUT: |
| 221 | return Errno::TIMEDOUT; | 228 | return Errno::TIMEDOUT; |
| 229 | case EINPROGRESS: | ||
| 230 | return Errno::INPROGRESS; | ||
| 222 | default: | 231 | default: |
| 223 | UNIMPLEMENTED_MSG("Unimplemented errno={}", e); | 232 | UNIMPLEMENTED_MSG("Unimplemented errno={} ({})", e, strerror(e)); |
| 224 | return Errno::OTHER; | 233 | return Errno::OTHER; |
| 225 | } | 234 | } |
| 226 | } | 235 | } |
| @@ -234,15 +243,84 @@ Errno GetAndLogLastError() { | |||
| 234 | int e = errno; | 243 | int e = errno; |
| 235 | #endif | 244 | #endif |
| 236 | const Errno err = TranslateNativeError(e); | 245 | const Errno err = TranslateNativeError(e); |
| 237 | if (err == Errno::AGAIN || err == Errno::TIMEDOUT) { | 246 | if (err == Errno::AGAIN || err == Errno::TIMEDOUT || err == Errno::INPROGRESS) { |
| 247 | // These happen during normal operation, so only log them at debug level. | ||
| 248 | LOG_DEBUG(Network, "Socket operation error: {}", Common::NativeErrorToString(e)); | ||
| 238 | return err; | 249 | return err; |
| 239 | } | 250 | } |
| 240 | LOG_ERROR(Network, "Socket operation error: {}", Common::NativeErrorToString(e)); | 251 | LOG_ERROR(Network, "Socket operation error: {}", Common::NativeErrorToString(e)); |
| 241 | return err; | 252 | return err; |
| 242 | } | 253 | } |
| 243 | 254 | ||
| 244 | int TranslateDomain(Domain domain) { | 255 | GetAddrInfoError TranslateGetAddrInfoErrorFromNative(int gai_err) { |
| 256 | switch (gai_err) { | ||
| 257 | case 0: | ||
| 258 | return GetAddrInfoError::SUCCESS; | ||
| 259 | #ifdef EAI_ADDRFAMILY | ||
| 260 | case EAI_ADDRFAMILY: | ||
| 261 | return GetAddrInfoError::ADDRFAMILY; | ||
| 262 | #endif | ||
| 263 | case EAI_AGAIN: | ||
| 264 | return GetAddrInfoError::AGAIN; | ||
| 265 | case EAI_BADFLAGS: | ||
| 266 | return GetAddrInfoError::BADFLAGS; | ||
| 267 | case EAI_FAIL: | ||
| 268 | return GetAddrInfoError::FAIL; | ||
| 269 | case EAI_FAMILY: | ||
| 270 | return GetAddrInfoError::FAMILY; | ||
| 271 | case EAI_MEMORY: | ||
| 272 | return GetAddrInfoError::MEMORY; | ||
| 273 | case EAI_NONAME: | ||
| 274 | return GetAddrInfoError::NONAME; | ||
| 275 | case EAI_SERVICE: | ||
| 276 | return GetAddrInfoError::SERVICE; | ||
| 277 | case EAI_SOCKTYPE: | ||
| 278 | return GetAddrInfoError::SOCKTYPE; | ||
| 279 | // These codes may not be defined on all systems: | ||
| 280 | #ifdef EAI_SYSTEM | ||
| 281 | case EAI_SYSTEM: | ||
| 282 | return GetAddrInfoError::SYSTEM; | ||
| 283 | #endif | ||
| 284 | #ifdef EAI_BADHINTS | ||
| 285 | case EAI_BADHINTS: | ||
| 286 | return GetAddrInfoError::BADHINTS; | ||
| 287 | #endif | ||
| 288 | #ifdef EAI_PROTOCOL | ||
| 289 | case EAI_PROTOCOL: | ||
| 290 | return GetAddrInfoError::PROTOCOL; | ||
| 291 | #endif | ||
| 292 | #ifdef EAI_OVERFLOW | ||
| 293 | case EAI_OVERFLOW: | ||
| 294 | return GetAddrInfoError::OVERFLOW_; | ||
| 295 | #endif | ||
| 296 | default: | ||
| 297 | #ifdef EAI_NODATA | ||
| 298 | // This can't be a case statement because it would create a duplicate | ||
| 299 | // case on Windows where EAI_NODATA is an alias for EAI_NONAME. | ||
| 300 | if (gai_err == EAI_NODATA) { | ||
| 301 | return GetAddrInfoError::NODATA; | ||
| 302 | } | ||
| 303 | #endif | ||
| 304 | return GetAddrInfoError::OTHER; | ||
| 305 | } | ||
| 306 | } | ||
| 307 | |||
| 308 | Domain TranslateDomainFromNative(int domain) { | ||
| 309 | switch (domain) { | ||
| 310 | case 0: | ||
| 311 | return Domain::Unspecified; | ||
| 312 | case AF_INET: | ||
| 313 | return Domain::INET; | ||
| 314 | default: | ||
| 315 | UNIMPLEMENTED_MSG("Unhandled domain={}", domain); | ||
| 316 | return Domain::INET; | ||
| 317 | } | ||
| 318 | } | ||
| 319 | |||
| 320 | int TranslateDomainToNative(Domain domain) { | ||
| 245 | switch (domain) { | 321 | switch (domain) { |
| 322 | case Domain::Unspecified: | ||
| 323 | return 0; | ||
| 246 | case Domain::INET: | 324 | case Domain::INET: |
| 247 | return AF_INET; | 325 | return AF_INET; |
| 248 | default: | 326 | default: |
| @@ -251,20 +329,58 @@ int TranslateDomain(Domain domain) { | |||
| 251 | } | 329 | } |
| 252 | } | 330 | } |
| 253 | 331 | ||
| 254 | int TranslateType(Type type) { | 332 | Type TranslateTypeFromNative(int type) { |
| 333 | switch (type) { | ||
| 334 | case 0: | ||
| 335 | return Type::Unspecified; | ||
| 336 | case SOCK_STREAM: | ||
| 337 | return Type::STREAM; | ||
| 338 | case SOCK_DGRAM: | ||
| 339 | return Type::DGRAM; | ||
| 340 | case SOCK_RAW: | ||
| 341 | return Type::RAW; | ||
| 342 | case SOCK_SEQPACKET: | ||
| 343 | return Type::SEQPACKET; | ||
| 344 | default: | ||
| 345 | UNIMPLEMENTED_MSG("Unimplemented type={}", type); | ||
| 346 | return Type::STREAM; | ||
| 347 | } | ||
| 348 | } | ||
| 349 | |||
| 350 | int TranslateTypeToNative(Type type) { | ||
| 255 | switch (type) { | 351 | switch (type) { |
| 352 | case Type::Unspecified: | ||
| 353 | return 0; | ||
| 256 | case Type::STREAM: | 354 | case Type::STREAM: |
| 257 | return SOCK_STREAM; | 355 | return SOCK_STREAM; |
| 258 | case Type::DGRAM: | 356 | case Type::DGRAM: |
| 259 | return SOCK_DGRAM; | 357 | return SOCK_DGRAM; |
| 358 | case Type::RAW: | ||
| 359 | return SOCK_RAW; | ||
| 260 | default: | 360 | default: |
| 261 | UNIMPLEMENTED_MSG("Unimplemented type={}", type); | 361 | UNIMPLEMENTED_MSG("Unimplemented type={}", type); |
| 262 | return 0; | 362 | return 0; |
| 263 | } | 363 | } |
| 264 | } | 364 | } |
| 265 | 365 | ||
| 266 | int TranslateProtocol(Protocol protocol) { | 366 | Protocol TranslateProtocolFromNative(int protocol) { |
| 367 | switch (protocol) { | ||
| 368 | case 0: | ||
| 369 | return Protocol::Unspecified; | ||
| 370 | case IPPROTO_TCP: | ||
| 371 | return Protocol::TCP; | ||
| 372 | case IPPROTO_UDP: | ||
| 373 | return Protocol::UDP; | ||
| 374 | default: | ||
| 375 | UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol); | ||
| 376 | return Protocol::Unspecified; | ||
| 377 | } | ||
| 378 | } | ||
| 379 | |||
| 380 | int TranslateProtocolToNative(Protocol protocol) { | ||
| 267 | switch (protocol) { | 381 | switch (protocol) { |
| 382 | case Protocol::Unspecified: | ||
| 383 | return 0; | ||
| 268 | case Protocol::TCP: | 384 | case Protocol::TCP: |
| 269 | return IPPROTO_TCP; | 385 | return IPPROTO_TCP; |
| 270 | case Protocol::UDP: | 386 | case Protocol::UDP: |
| @@ -275,21 +391,10 @@ int TranslateProtocol(Protocol protocol) { | |||
| 275 | } | 391 | } |
| 276 | } | 392 | } |
| 277 | 393 | ||
| 278 | SockAddrIn TranslateToSockAddrIn(sockaddr input_) { | 394 | SockAddrIn TranslateToSockAddrIn(sockaddr_in input, size_t input_len) { |
| 279 | sockaddr_in input; | ||
| 280 | std::memcpy(&input, &input_, sizeof(input)); | ||
| 281 | |||
| 282 | SockAddrIn result; | 395 | SockAddrIn result; |
| 283 | 396 | ||
| 284 | switch (input.sin_family) { | 397 | result.family = TranslateDomainFromNative(input.sin_family); |
| 285 | case AF_INET: | ||
| 286 | result.family = Domain::INET; | ||
| 287 | break; | ||
| 288 | default: | ||
| 289 | UNIMPLEMENTED_MSG("Unhandled sockaddr family={}", input.sin_family); | ||
| 290 | result.family = Domain::INET; | ||
| 291 | break; | ||
| 292 | } | ||
| 293 | 398 | ||
| 294 | result.portno = ntohs(input.sin_port); | 399 | result.portno = ntohs(input.sin_port); |
| 295 | 400 | ||
| @@ -301,22 +406,33 @@ SockAddrIn TranslateToSockAddrIn(sockaddr input_) { | |||
| 301 | short TranslatePollEvents(PollEvents events) { | 406 | short TranslatePollEvents(PollEvents events) { |
| 302 | short result = 0; | 407 | short result = 0; |
| 303 | 408 | ||
| 304 | if (True(events & PollEvents::In)) { | 409 | const auto translate = [&result, &events](PollEvents guest, short host) { |
| 305 | events &= ~PollEvents::In; | 410 | if (True(events & guest)) { |
| 306 | result |= POLLIN; | 411 | events &= ~guest; |
| 307 | } | 412 | result |= host; |
| 308 | if (True(events & PollEvents::Pri)) { | 413 | } |
| 309 | events &= ~PollEvents::Pri; | 414 | }; |
| 415 | |||
| 416 | translate(PollEvents::In, POLLIN); | ||
| 417 | translate(PollEvents::Pri, POLLPRI); | ||
| 418 | translate(PollEvents::Out, POLLOUT); | ||
| 419 | translate(PollEvents::Err, POLLERR); | ||
| 420 | translate(PollEvents::Hup, POLLHUP); | ||
| 421 | translate(PollEvents::Nval, POLLNVAL); | ||
| 422 | translate(PollEvents::RdNorm, POLLRDNORM); | ||
| 423 | translate(PollEvents::RdBand, POLLRDBAND); | ||
| 424 | translate(PollEvents::WrBand, POLLWRBAND); | ||
| 425 | |||
| 310 | #ifdef _WIN32 | 426 | #ifdef _WIN32 |
| 311 | LOG_WARNING(Service, "Winsock doesn't support POLLPRI"); | 427 | short allowed_events = POLLRDBAND | POLLRDNORM | POLLWRNORM; |
| 312 | #else | 428 | // Unlike poll on other OSes, WSAPoll will complain if any other flags are set on input. |
| 313 | result |= POLLPRI; | 429 | if (result & ~allowed_events) { |
| 430 | LOG_DEBUG(Network, | ||
| 431 | "Removing WSAPoll input events 0x{:x} because Windows doesn't support them", | ||
| 432 | result & ~allowed_events); | ||
| 433 | } | ||
| 434 | result &= allowed_events; | ||
| 314 | #endif | 435 | #endif |
| 315 | } | ||
| 316 | if (True(events & PollEvents::Out)) { | ||
| 317 | events &= ~PollEvents::Out; | ||
| 318 | result |= POLLOUT; | ||
| 319 | } | ||
| 320 | 436 | ||
| 321 | UNIMPLEMENTED_IF_MSG((u16)events != 0, "Unhandled guest events=0x{:x}", (u16)events); | 437 | UNIMPLEMENTED_IF_MSG((u16)events != 0, "Unhandled guest events=0x{:x}", (u16)events); |
| 322 | 438 | ||
| @@ -337,6 +453,10 @@ PollEvents TranslatePollRevents(short revents) { | |||
| 337 | translate(POLLOUT, PollEvents::Out); | 453 | translate(POLLOUT, PollEvents::Out); |
| 338 | translate(POLLERR, PollEvents::Err); | 454 | translate(POLLERR, PollEvents::Err); |
| 339 | translate(POLLHUP, PollEvents::Hup); | 455 | translate(POLLHUP, PollEvents::Hup); |
| 456 | translate(POLLNVAL, PollEvents::Nval); | ||
| 457 | translate(POLLRDNORM, PollEvents::RdNorm); | ||
| 458 | translate(POLLRDBAND, PollEvents::RdBand); | ||
| 459 | translate(POLLWRBAND, PollEvents::WrBand); | ||
| 340 | 460 | ||
| 341 | UNIMPLEMENTED_IF_MSG(revents != 0, "Unhandled host revents=0x{:x}", revents); | 461 | UNIMPLEMENTED_IF_MSG(revents != 0, "Unhandled host revents=0x{:x}", revents); |
| 342 | 462 | ||
| @@ -360,12 +480,51 @@ std::optional<IPv4Address> GetHostIPv4Address() { | |||
| 360 | return {}; | 480 | return {}; |
| 361 | } | 481 | } |
| 362 | 482 | ||
| 363 | std::array<char, 16> ip_addr = {}; | ||
| 364 | ASSERT(inet_ntop(AF_INET, &network_interface->ip_address, ip_addr.data(), sizeof(ip_addr)) != | ||
| 365 | nullptr); | ||
| 366 | return TranslateIPv4(network_interface->ip_address); | 483 | return TranslateIPv4(network_interface->ip_address); |
| 367 | } | 484 | } |
| 368 | 485 | ||
| 486 | std::string IPv4AddressToString(IPv4Address ip_addr) { | ||
| 487 | std::array<char, INET_ADDRSTRLEN> buf = {}; | ||
| 488 | ASSERT(inet_ntop(AF_INET, &ip_addr, buf.data(), sizeof(buf)) == buf.data()); | ||
| 489 | return std::string(buf.data()); | ||
| 490 | } | ||
| 491 | |||
| 492 | u32 IPv4AddressToInteger(IPv4Address ip_addr) { | ||
| 493 | return static_cast<u32>(ip_addr[0]) << 24 | static_cast<u32>(ip_addr[1]) << 16 | | ||
| 494 | static_cast<u32>(ip_addr[2]) << 8 | static_cast<u32>(ip_addr[3]); | ||
| 495 | } | ||
| 496 | |||
| 497 | Common::Expected<std::vector<AddrInfo>, GetAddrInfoError> GetAddressInfo( | ||
| 498 | const std::string& host, const std::optional<std::string>& service) { | ||
| 499 | addrinfo hints{}; | ||
| 500 | hints.ai_family = AF_INET; // Switch only supports IPv4. | ||
| 501 | addrinfo* addrinfo; | ||
| 502 | s32 gai_err = getaddrinfo(host.c_str(), service.has_value() ? service->c_str() : nullptr, | ||
| 503 | &hints, &addrinfo); | ||
| 504 | if (gai_err != 0) { | ||
| 505 | return Common::Unexpected(TranslateGetAddrInfoErrorFromNative(gai_err)); | ||
| 506 | } | ||
| 507 | std::vector<AddrInfo> ret; | ||
| 508 | for (auto* current = addrinfo; current; current = current->ai_next) { | ||
| 509 | // We should only get AF_INET results due to the hints value. | ||
| 510 | ASSERT_OR_EXECUTE(addrinfo->ai_family == AF_INET && | ||
| 511 | addrinfo->ai_addrlen == sizeof(sockaddr_in), | ||
| 512 | continue;); | ||
| 513 | |||
| 514 | AddrInfo& out = ret.emplace_back(); | ||
| 515 | out.family = TranslateDomainFromNative(current->ai_family); | ||
| 516 | out.socket_type = TranslateTypeFromNative(current->ai_socktype); | ||
| 517 | out.protocol = TranslateProtocolFromNative(current->ai_protocol); | ||
| 518 | out.addr = TranslateToSockAddrIn(*reinterpret_cast<sockaddr_in*>(current->ai_addr), | ||
| 519 | current->ai_addrlen); | ||
| 520 | if (current->ai_canonname != nullptr) { | ||
| 521 | out.canon_name = current->ai_canonname; | ||
| 522 | } | ||
| 523 | } | ||
| 524 | freeaddrinfo(addrinfo); | ||
| 525 | return ret; | ||
| 526 | } | ||
| 527 | |||
| 369 | std::pair<s32, Errno> Poll(std::vector<PollFD>& pollfds, s32 timeout) { | 528 | std::pair<s32, Errno> Poll(std::vector<PollFD>& pollfds, s32 timeout) { |
| 370 | const size_t num = pollfds.size(); | 529 | const size_t num = pollfds.size(); |
| 371 | 530 | ||
| @@ -411,9 +570,21 @@ Socket::Socket(Socket&& rhs) noexcept { | |||
| 411 | } | 570 | } |
| 412 | 571 | ||
| 413 | template <typename T> | 572 | template <typename T> |
| 414 | Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) { | 573 | std::pair<T, Errno> Socket::GetSockOpt(SOCKET fd_so, int option) { |
| 574 | T value{}; | ||
| 575 | socklen_t len = sizeof(value); | ||
| 576 | const int result = getsockopt(fd_so, SOL_SOCKET, option, reinterpret_cast<char*>(&value), &len); | ||
| 577 | if (result != SOCKET_ERROR) { | ||
| 578 | ASSERT(len == sizeof(value)); | ||
| 579 | return {value, Errno::SUCCESS}; | ||
| 580 | } | ||
| 581 | return {value, GetAndLogLastError()}; | ||
| 582 | } | ||
| 583 | |||
| 584 | template <typename T> | ||
| 585 | Errno Socket::SetSockOpt(SOCKET fd_so, int option, T value) { | ||
| 415 | const int result = | 586 | const int result = |
| 416 | setsockopt(fd_, SOL_SOCKET, option, reinterpret_cast<const char*>(&value), sizeof(value)); | 587 | setsockopt(fd_so, SOL_SOCKET, option, reinterpret_cast<const char*>(&value), sizeof(value)); |
| 417 | if (result != SOCKET_ERROR) { | 588 | if (result != SOCKET_ERROR) { |
| 418 | return Errno::SUCCESS; | 589 | return Errno::SUCCESS; |
| 419 | } | 590 | } |
| @@ -421,7 +592,8 @@ Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) { | |||
| 421 | } | 592 | } |
| 422 | 593 | ||
| 423 | Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) { | 594 | Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) { |
| 424 | fd = socket(TranslateDomain(domain), TranslateType(type), TranslateProtocol(protocol)); | 595 | fd = socket(TranslateDomainToNative(domain), TranslateTypeToNative(type), |
| 596 | TranslateProtocolToNative(protocol)); | ||
| 425 | if (fd != INVALID_SOCKET) { | 597 | if (fd != INVALID_SOCKET) { |
| 426 | return Errno::SUCCESS; | 598 | return Errno::SUCCESS; |
| 427 | } | 599 | } |
| @@ -430,19 +602,17 @@ Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) { | |||
| 430 | } | 602 | } |
| 431 | 603 | ||
| 432 | std::pair<SocketBase::AcceptResult, Errno> Socket::Accept() { | 604 | std::pair<SocketBase::AcceptResult, Errno> Socket::Accept() { |
| 433 | sockaddr addr; | 605 | sockaddr_in addr; |
| 434 | socklen_t addrlen = sizeof(addr); | 606 | socklen_t addrlen = sizeof(addr); |
| 435 | const SOCKET new_socket = accept(fd, &addr, &addrlen); | 607 | const SOCKET new_socket = accept(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen); |
| 436 | 608 | ||
| 437 | if (new_socket == INVALID_SOCKET) { | 609 | if (new_socket == INVALID_SOCKET) { |
| 438 | return {AcceptResult{}, GetAndLogLastError()}; | 610 | return {AcceptResult{}, GetAndLogLastError()}; |
| 439 | } | 611 | } |
| 440 | 612 | ||
| 441 | ASSERT(addrlen == sizeof(sockaddr_in)); | ||
| 442 | |||
| 443 | AcceptResult result{ | 613 | AcceptResult result{ |
| 444 | .socket = std::make_unique<Socket>(new_socket), | 614 | .socket = std::make_unique<Socket>(new_socket), |
| 445 | .sockaddr_in = TranslateToSockAddrIn(addr), | 615 | .sockaddr_in = TranslateToSockAddrIn(addr, addrlen), |
| 446 | }; | 616 | }; |
| 447 | 617 | ||
| 448 | return {std::move(result), Errno::SUCCESS}; | 618 | return {std::move(result), Errno::SUCCESS}; |
| @@ -458,25 +628,23 @@ Errno Socket::Connect(SockAddrIn addr_in) { | |||
| 458 | } | 628 | } |
| 459 | 629 | ||
| 460 | std::pair<SockAddrIn, Errno> Socket::GetPeerName() { | 630 | std::pair<SockAddrIn, Errno> Socket::GetPeerName() { |
| 461 | sockaddr addr; | 631 | sockaddr_in addr; |
| 462 | socklen_t addrlen = sizeof(addr); | 632 | socklen_t addrlen = sizeof(addr); |
| 463 | if (getpeername(fd, &addr, &addrlen) == SOCKET_ERROR) { | 633 | if (getpeername(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen) == SOCKET_ERROR) { |
| 464 | return {SockAddrIn{}, GetAndLogLastError()}; | 634 | return {SockAddrIn{}, GetAndLogLastError()}; |
| 465 | } | 635 | } |
| 466 | 636 | ||
| 467 | ASSERT(addrlen == sizeof(sockaddr_in)); | 637 | return {TranslateToSockAddrIn(addr, addrlen), Errno::SUCCESS}; |
| 468 | return {TranslateToSockAddrIn(addr), Errno::SUCCESS}; | ||
| 469 | } | 638 | } |
| 470 | 639 | ||
| 471 | std::pair<SockAddrIn, Errno> Socket::GetSockName() { | 640 | std::pair<SockAddrIn, Errno> Socket::GetSockName() { |
| 472 | sockaddr addr; | 641 | sockaddr_in addr; |
| 473 | socklen_t addrlen = sizeof(addr); | 642 | socklen_t addrlen = sizeof(addr); |
| 474 | if (getsockname(fd, &addr, &addrlen) == SOCKET_ERROR) { | 643 | if (getsockname(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen) == SOCKET_ERROR) { |
| 475 | return {SockAddrIn{}, GetAndLogLastError()}; | 644 | return {SockAddrIn{}, GetAndLogLastError()}; |
| 476 | } | 645 | } |
| 477 | 646 | ||
| 478 | ASSERT(addrlen == sizeof(sockaddr_in)); | 647 | return {TranslateToSockAddrIn(addr, addrlen), Errno::SUCCESS}; |
| 479 | return {TranslateToSockAddrIn(addr), Errno::SUCCESS}; | ||
| 480 | } | 648 | } |
| 481 | 649 | ||
| 482 | Errno Socket::Bind(SockAddrIn addr) { | 650 | Errno Socket::Bind(SockAddrIn addr) { |
| @@ -519,7 +687,7 @@ Errno Socket::Shutdown(ShutdownHow how) { | |||
| 519 | return GetAndLogLastError(); | 687 | return GetAndLogLastError(); |
| 520 | } | 688 | } |
| 521 | 689 | ||
| 522 | std::pair<s32, Errno> Socket::Recv(int flags, std::vector<u8>& message) { | 690 | std::pair<s32, Errno> Socket::Recv(int flags, std::span<u8> message) { |
| 523 | ASSERT(flags == 0); | 691 | ASSERT(flags == 0); |
| 524 | ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); | 692 | ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); |
| 525 | 693 | ||
| @@ -532,21 +700,20 @@ std::pair<s32, Errno> Socket::Recv(int flags, std::vector<u8>& message) { | |||
| 532 | return {-1, GetAndLogLastError()}; | 700 | return {-1, GetAndLogLastError()}; |
| 533 | } | 701 | } |
| 534 | 702 | ||
| 535 | std::pair<s32, Errno> Socket::RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) { | 703 | std::pair<s32, Errno> Socket::RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) { |
| 536 | ASSERT(flags == 0); | 704 | ASSERT(flags == 0); |
| 537 | ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); | 705 | ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); |
| 538 | 706 | ||
| 539 | sockaddr addr_in{}; | 707 | sockaddr_in addr_in{}; |
| 540 | socklen_t addrlen = sizeof(addr_in); | 708 | socklen_t addrlen = sizeof(addr_in); |
| 541 | socklen_t* const p_addrlen = addr ? &addrlen : nullptr; | 709 | socklen_t* const p_addrlen = addr ? &addrlen : nullptr; |
| 542 | sockaddr* const p_addr_in = addr ? &addr_in : nullptr; | 710 | sockaddr* const p_addr_in = addr ? reinterpret_cast<sockaddr*>(&addr_in) : nullptr; |
| 543 | 711 | ||
| 544 | const auto result = recvfrom(fd, reinterpret_cast<char*>(message.data()), | 712 | const auto result = recvfrom(fd, reinterpret_cast<char*>(message.data()), |
| 545 | static_cast<int>(message.size()), 0, p_addr_in, p_addrlen); | 713 | static_cast<int>(message.size()), 0, p_addr_in, p_addrlen); |
| 546 | if (result != SOCKET_ERROR) { | 714 | if (result != SOCKET_ERROR) { |
| 547 | if (addr) { | 715 | if (addr) { |
| 548 | ASSERT(addrlen == sizeof(addr_in)); | 716 | *addr = TranslateToSockAddrIn(addr_in, addrlen); |
| 549 | *addr = TranslateToSockAddrIn(addr_in); | ||
| 550 | } | 717 | } |
| 551 | return {static_cast<s32>(result), Errno::SUCCESS}; | 718 | return {static_cast<s32>(result), Errno::SUCCESS}; |
| 552 | } | 719 | } |
| @@ -597,6 +764,11 @@ Errno Socket::Close() { | |||
| 597 | return Errno::SUCCESS; | 764 | return Errno::SUCCESS; |
| 598 | } | 765 | } |
| 599 | 766 | ||
| 767 | std::pair<Errno, Errno> Socket::GetPendingError() { | ||
| 768 | auto [pending_err, getsockopt_err] = GetSockOpt<int>(fd, SO_ERROR); | ||
| 769 | return {TranslateNativeError(pending_err), getsockopt_err}; | ||
| 770 | } | ||
| 771 | |||
| 600 | Errno Socket::SetLinger(bool enable, u32 linger) { | 772 | Errno Socket::SetLinger(bool enable, u32 linger) { |
| 601 | return SetSockOpt(fd, SO_LINGER, MakeLinger(enable, linger)); | 773 | return SetSockOpt(fd, SO_LINGER, MakeLinger(enable, linger)); |
| 602 | } | 774 | } |
diff --git a/src/core/internal_network/network.h b/src/core/internal_network/network.h index 1e09a007a..badcb8369 100644 --- a/src/core/internal_network/network.h +++ b/src/core/internal_network/network.h | |||
| @@ -5,6 +5,7 @@ | |||
| 5 | 5 | ||
| 6 | #include <array> | 6 | #include <array> |
| 7 | #include <optional> | 7 | #include <optional> |
| 8 | #include <vector> | ||
| 8 | 9 | ||
| 9 | #include "common/common_funcs.h" | 10 | #include "common/common_funcs.h" |
| 10 | #include "common/common_types.h" | 11 | #include "common/common_types.h" |
| @@ -16,6 +17,11 @@ | |||
| 16 | #include <netinet/in.h> | 17 | #include <netinet/in.h> |
| 17 | #endif | 18 | #endif |
| 18 | 19 | ||
| 20 | namespace Common { | ||
| 21 | template <typename T, typename E> | ||
| 22 | class Expected; | ||
| 23 | } | ||
| 24 | |||
| 19 | namespace Network { | 25 | namespace Network { |
| 20 | 26 | ||
| 21 | class SocketBase; | 27 | class SocketBase; |
| @@ -36,6 +42,26 @@ enum class Errno { | |||
| 36 | NETUNREACH, | 42 | NETUNREACH, |
| 37 | TIMEDOUT, | 43 | TIMEDOUT, |
| 38 | MSGSIZE, | 44 | MSGSIZE, |
| 45 | INPROGRESS, | ||
| 46 | OTHER, | ||
| 47 | }; | ||
| 48 | |||
| 49 | enum class GetAddrInfoError { | ||
| 50 | SUCCESS, | ||
| 51 | ADDRFAMILY, | ||
| 52 | AGAIN, | ||
| 53 | BADFLAGS, | ||
| 54 | FAIL, | ||
| 55 | FAMILY, | ||
| 56 | MEMORY, | ||
| 57 | NODATA, | ||
| 58 | NONAME, | ||
| 59 | SERVICE, | ||
| 60 | SOCKTYPE, | ||
| 61 | SYSTEM, | ||
| 62 | BADHINTS, | ||
| 63 | PROTOCOL, | ||
| 64 | OVERFLOW_, | ||
| 39 | OTHER, | 65 | OTHER, |
| 40 | }; | 66 | }; |
| 41 | 67 | ||
| @@ -49,6 +75,9 @@ enum class PollEvents : u16 { | |||
| 49 | Err = 1 << 3, | 75 | Err = 1 << 3, |
| 50 | Hup = 1 << 4, | 76 | Hup = 1 << 4, |
| 51 | Nval = 1 << 5, | 77 | Nval = 1 << 5, |
| 78 | RdNorm = 1 << 6, | ||
| 79 | RdBand = 1 << 7, | ||
| 80 | WrBand = 1 << 8, | ||
| 52 | }; | 81 | }; |
| 53 | 82 | ||
| 54 | DECLARE_ENUM_FLAG_OPERATORS(PollEvents); | 83 | DECLARE_ENUM_FLAG_OPERATORS(PollEvents); |
| @@ -82,4 +111,11 @@ constexpr IPv4Address TranslateIPv4(in_addr addr) { | |||
| 82 | /// @return human ordered IPv4 address (e.g. 192.168.0.1) as an array | 111 | /// @return human ordered IPv4 address (e.g. 192.168.0.1) as an array |
| 83 | std::optional<IPv4Address> GetHostIPv4Address(); | 112 | std::optional<IPv4Address> GetHostIPv4Address(); |
| 84 | 113 | ||
| 114 | std::string IPv4AddressToString(IPv4Address ip_addr); | ||
| 115 | u32 IPv4AddressToInteger(IPv4Address ip_addr); | ||
| 116 | |||
| 117 | // named to avoid name collision with Windows macro | ||
| 118 | Common::Expected<std::vector<AddrInfo>, GetAddrInfoError> GetAddressInfo( | ||
| 119 | const std::string& host, const std::optional<std::string>& service); | ||
| 120 | |||
| 85 | } // namespace Network | 121 | } // namespace Network |
diff --git a/src/core/internal_network/socket_proxy.cpp b/src/core/internal_network/socket_proxy.cpp index 7a77171c2..44e9e3093 100644 --- a/src/core/internal_network/socket_proxy.cpp +++ b/src/core/internal_network/socket_proxy.cpp | |||
| @@ -98,7 +98,7 @@ Errno ProxySocket::Shutdown(ShutdownHow how) { | |||
| 98 | return Errno::SUCCESS; | 98 | return Errno::SUCCESS; |
| 99 | } | 99 | } |
| 100 | 100 | ||
| 101 | std::pair<s32, Errno> ProxySocket::Recv(int flags, std::vector<u8>& message) { | 101 | std::pair<s32, Errno> ProxySocket::Recv(int flags, std::span<u8> message) { |
| 102 | LOG_WARNING(Network, "(STUBBED) called"); | 102 | LOG_WARNING(Network, "(STUBBED) called"); |
| 103 | ASSERT(flags == 0); | 103 | ASSERT(flags == 0); |
| 104 | ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); | 104 | ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); |
| @@ -106,7 +106,7 @@ std::pair<s32, Errno> ProxySocket::Recv(int flags, std::vector<u8>& message) { | |||
| 106 | return {static_cast<s32>(0), Errno::SUCCESS}; | 106 | return {static_cast<s32>(0), Errno::SUCCESS}; |
| 107 | } | 107 | } |
| 108 | 108 | ||
| 109 | std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) { | 109 | std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) { |
| 110 | ASSERT(flags == 0); | 110 | ASSERT(flags == 0); |
| 111 | ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); | 111 | ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); |
| 112 | 112 | ||
| @@ -140,8 +140,8 @@ std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::vector<u8>& message, | |||
| 140 | } | 140 | } |
| 141 | } | 141 | } |
| 142 | 142 | ||
| 143 | std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& message, | 143 | std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::span<u8> message, SockAddrIn* addr, |
| 144 | SockAddrIn* addr, std::size_t max_length) { | 144 | std::size_t max_length) { |
| 145 | ProxyPacket& packet = received_packets.front(); | 145 | ProxyPacket& packet = received_packets.front(); |
| 146 | if (addr) { | 146 | if (addr) { |
| 147 | addr->family = Domain::INET; | 147 | addr->family = Domain::INET; |
| @@ -153,10 +153,7 @@ std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& mes | |||
| 153 | std::size_t read_bytes; | 153 | std::size_t read_bytes; |
| 154 | if (packet.data.size() > max_length) { | 154 | if (packet.data.size() > max_length) { |
| 155 | read_bytes = max_length; | 155 | read_bytes = max_length; |
| 156 | message.clear(); | 156 | memcpy(message.data(), packet.data.data(), max_length); |
| 157 | std::copy(packet.data.begin(), packet.data.begin() + read_bytes, | ||
| 158 | std::back_inserter(message)); | ||
| 159 | message.resize(max_length); | ||
| 160 | 157 | ||
| 161 | if (protocol == Protocol::UDP) { | 158 | if (protocol == Protocol::UDP) { |
| 162 | if (!peek) { | 159 | if (!peek) { |
| @@ -171,9 +168,7 @@ std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& mes | |||
| 171 | } | 168 | } |
| 172 | } else { | 169 | } else { |
| 173 | read_bytes = packet.data.size(); | 170 | read_bytes = packet.data.size(); |
| 174 | message.clear(); | 171 | memcpy(message.data(), packet.data.data(), read_bytes); |
| 175 | std::copy(packet.data.begin(), packet.data.end(), std::back_inserter(message)); | ||
| 176 | message.resize(max_length); | ||
| 177 | if (!peek) { | 172 | if (!peek) { |
| 178 | received_packets.pop(); | 173 | received_packets.pop(); |
| 179 | } | 174 | } |
| @@ -293,6 +288,11 @@ Errno ProxySocket::SetNonBlock(bool enable) { | |||
| 293 | return Errno::SUCCESS; | 288 | return Errno::SUCCESS; |
| 294 | } | 289 | } |
| 295 | 290 | ||
| 291 | std::pair<Errno, Errno> ProxySocket::GetPendingError() { | ||
| 292 | LOG_DEBUG(Network, "(STUBBED) called"); | ||
| 293 | return {Errno::SUCCESS, Errno::SUCCESS}; | ||
| 294 | } | ||
| 295 | |||
| 296 | bool ProxySocket::IsOpened() const { | 296 | bool ProxySocket::IsOpened() const { |
| 297 | return fd != INVALID_SOCKET; | 297 | return fd != INVALID_SOCKET; |
| 298 | } | 298 | } |
diff --git a/src/core/internal_network/socket_proxy.h b/src/core/internal_network/socket_proxy.h index 6e991fa38..e12c413d1 100644 --- a/src/core/internal_network/socket_proxy.h +++ b/src/core/internal_network/socket_proxy.h | |||
| @@ -39,11 +39,11 @@ public: | |||
| 39 | 39 | ||
| 40 | Errno Shutdown(ShutdownHow how) override; | 40 | Errno Shutdown(ShutdownHow how) override; |
| 41 | 41 | ||
| 42 | std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) override; | 42 | std::pair<s32, Errno> Recv(int flags, std::span<u8> message) override; |
| 43 | 43 | ||
| 44 | std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) override; | 44 | std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) override; |
| 45 | 45 | ||
| 46 | std::pair<s32, Errno> ReceivePacket(int flags, std::vector<u8>& message, SockAddrIn* addr, | 46 | std::pair<s32, Errno> ReceivePacket(int flags, std::span<u8> message, SockAddrIn* addr, |
| 47 | std::size_t max_length); | 47 | std::size_t max_length); |
| 48 | 48 | ||
| 49 | std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override; | 49 | std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override; |
| @@ -74,6 +74,8 @@ public: | |||
| 74 | template <typename T> | 74 | template <typename T> |
| 75 | Errno SetSockOpt(SOCKET fd, int option, T value); | 75 | Errno SetSockOpt(SOCKET fd, int option, T value); |
| 76 | 76 | ||
| 77 | std::pair<Errno, Errno> GetPendingError() override; | ||
| 78 | |||
| 77 | bool IsOpened() const override; | 79 | bool IsOpened() const override; |
| 78 | 80 | ||
| 79 | private: | 81 | private: |
diff --git a/src/core/internal_network/sockets.h b/src/core/internal_network/sockets.h index 11e479e50..46a53ef79 100644 --- a/src/core/internal_network/sockets.h +++ b/src/core/internal_network/sockets.h | |||
| @@ -59,10 +59,9 @@ public: | |||
| 59 | 59 | ||
| 60 | virtual Errno Shutdown(ShutdownHow how) = 0; | 60 | virtual Errno Shutdown(ShutdownHow how) = 0; |
| 61 | 61 | ||
| 62 | virtual std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) = 0; | 62 | virtual std::pair<s32, Errno> Recv(int flags, std::span<u8> message) = 0; |
| 63 | 63 | ||
| 64 | virtual std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message, | 64 | virtual std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) = 0; |
| 65 | SockAddrIn* addr) = 0; | ||
| 66 | 65 | ||
| 67 | virtual std::pair<s32, Errno> Send(std::span<const u8> message, int flags) = 0; | 66 | virtual std::pair<s32, Errno> Send(std::span<const u8> message, int flags) = 0; |
| 68 | 67 | ||
| @@ -87,6 +86,8 @@ public: | |||
| 87 | 86 | ||
| 88 | virtual Errno SetNonBlock(bool enable) = 0; | 87 | virtual Errno SetNonBlock(bool enable) = 0; |
| 89 | 88 | ||
| 89 | virtual std::pair<Errno, Errno> GetPendingError() = 0; | ||
| 90 | |||
| 90 | virtual bool IsOpened() const = 0; | 91 | virtual bool IsOpened() const = 0; |
| 91 | 92 | ||
| 92 | virtual void HandleProxyPacket(const ProxyPacket& packet) = 0; | 93 | virtual void HandleProxyPacket(const ProxyPacket& packet) = 0; |
| @@ -126,9 +127,9 @@ public: | |||
| 126 | 127 | ||
| 127 | Errno Shutdown(ShutdownHow how) override; | 128 | Errno Shutdown(ShutdownHow how) override; |
| 128 | 129 | ||
| 129 | std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) override; | 130 | std::pair<s32, Errno> Recv(int flags, std::span<u8> message) override; |
| 130 | 131 | ||
| 131 | std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) override; | 132 | std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) override; |
| 132 | 133 | ||
| 133 | std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override; | 134 | std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override; |
| 134 | 135 | ||
| @@ -156,6 +157,11 @@ public: | |||
| 156 | template <typename T> | 157 | template <typename T> |
| 157 | Errno SetSockOpt(SOCKET fd, int option, T value); | 158 | Errno SetSockOpt(SOCKET fd, int option, T value); |
| 158 | 159 | ||
| 160 | std::pair<Errno, Errno> GetPendingError() override; | ||
| 161 | |||
| 162 | template <typename T> | ||
| 163 | std::pair<T, Errno> GetSockOpt(SOCKET fd, int option); | ||
| 164 | |||
| 159 | bool IsOpened() const override; | 165 | bool IsOpened() const override; |
| 160 | 166 | ||
| 161 | void HandleProxyPacket(const ProxyPacket& packet) override; | 167 | void HandleProxyPacket(const ProxyPacket& packet) override; |